diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index e04eac5e4f604b..bb254f0c69c5aa 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1,7 +1,7 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt --- stablehlo/CMakeLists.txt +++ stablehlo/CMakeLists.txt -@@ -13,135 +13,20 @@ +@@ -13,131 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # @@ -134,10 +134,6 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt -#------------------------------------------------------------------------------- - -if(STABLEHLO_ENABLE_BINDINGS_PYTHON) -- if(NOT STABLEHLO_EXTERNAL_PROJECT_BUILD) -- message(WARNING "StableHLO Python bindings are not supported in standalone mode") -- endif() -- - include(MLIRDetectPythonEnv) - mlir_configure_python_dev_packages() -endif() @@ -145,26 +141,6 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt #------------------------------------------------------------------------------- # Directory setup -diff --ruN a/stablehlo/docs/_toc.yaml b/stablehlo/docs/_toc.yaml ---- stablehlo/docs/_toc.yaml -+++ stablehlo/docs/_toc.yaml -@@ -1,3 +1,16 @@ -+# Copyright 2023 The StableHLO Authors. -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. -+ - toc: - - heading: StableHLO developer guide - - title: Overview diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists.txt --- stablehlo/stablehlo/CMakeLists.txt +++ stablehlo/stablehlo/CMakeLists.txt @@ -176,18 +152,6 @@ diff --ruN a/stablehlo/stablehlo/CMakeLists.txt b/stablehlo/stablehlo/CMakeLists add_subdirectory(integrations) add_subdirectory(reference) add_subdirectory(tests) -diff --ruN a/stablehlo/stablehlo/conversions/linalg/tests/reduce.mlir b/stablehlo/stablehlo/conversions/linalg/tests/reduce.mlir ---- stablehlo/stablehlo/conversions/linalg/tests/reduce.mlir -+++ stablehlo/stablehlo/conversions/linalg/tests/reduce.mlir -@@ -29,7 +29,7 @@ - // CHECK-PRIMITIVE-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor - // CHECK-PRIMITIVE-DAG: %[[INIT_TENSOR:.*]] = tensor.empty() - // CHECK-PRIMITIVE-DAG: %[[FILL_TENSOR:.*]] = linalg.fill ins(%[[INIT]]{{.*}}outs(%[[INIT_TENSOR]] --// CHECK-PRIMITIVE: linalg.reduce { arith.addi } -+// CHECK-PRIMITIVE: linalg.reduce { arith.addi {overflowFlags = #arith.overflow} } - // CHECK-PRIMITIVE-SAME: ins(%{{.*}}tensor<5x4xi32>) - // CHECK-PRIMITIVE-SAME: outs(%[[FILL_TENSOR]] : tensor<5xi32>) - // CHECK-PRIMITIVE-SAME: dimensions = [1] {someattr} diff --ruN a/stablehlo/stablehlo/experimental/BUILD.bazel b/stablehlo/stablehlo/experimental/BUILD.bazel --- stablehlo/stablehlo/experimental/BUILD.bazel +++ stablehlo/stablehlo/experimental/BUILD.bazel @@ -2756,83 +2720,4 @@ diff --ruN a/stablehlo/stablehlo/experimental/transforms/StablehloRefineShapes.c +} // namespace experimental +} // namespace stablehlo +} // namespace mlir -diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp ---- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp -+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp -@@ -967,16 +967,16 @@ - // better error reporting for this case. - // This serves the current use cases well, so the implementation of more - // sophisticated refinement algorithm is left for future work. -- rewriter.startRootUpdate(op); -+ rewriter.startOpModification(op); - auto condStatus = refineValues(rewriter, op, op.getCond().getArguments(), - op.getOperandTypes()); - auto bodyStatus = refineValues(rewriter, op, op.getBody().getArguments(), - op.getOperandTypes()); - if (succeeded(condStatus) || succeeded(bodyStatus)) { -- rewriter.finalizeRootUpdate(op); -+ rewriter.finalizeOpModification(op); - return success(); - } else { -- rewriter.cancelRootUpdate(op); -+ rewriter.cancelOpModification(op); - return failure(); - } - } -@@ -1055,7 +1055,7 @@ - if (!needsUpdate) - return rewriter.notifyMatchFailure(op, "doesn't need update"); - -- rewriter.updateRootInPlace(op->getParentOp(), [&]() { return; }); -+ rewriter.modifyOpInPlace(op->getParentOp(), [&]() { return; }); - return success(); - } - }; -diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp ---- stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp -+++ stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp -@@ -426,15 +426,19 @@ - return success(); - } - --template --SpecialResult convertDenseArray(StringAttr vhloName, Attribute vhloAttr, -- SmallVector& stablehloAttrs) { -+SpecialResult convertDenseI64Array( -+ StringAttr vhloName, Attribute vhloAttr, -+ SmallVector& stablehloAttrs) { - auto tensorAttr = dyn_cast(vhloAttr); - if (!tensorAttr) return specialFailure(); - -- auto data = SmallVector( -- ArrayRef(reinterpret_cast(tensorAttr.getData().data()), -- tensorAttr.getData().size() / sizeof(T))); -+ if (tensorAttr.getData().size() % sizeof(int64_t) != 0) -+ return specialFailure(); -+ -+ auto data = ArrayRef( -+ reinterpret_cast(tensorAttr.getData().data()), -+ tensorAttr.getData().size() / sizeof(int64_t)) -+ .vec(); - - // Handle splats - if (data.size() == 1) { -@@ -445,15 +449,9 @@ - data.resize(size, data[0]); - } - -- stablehloAttrs.emplace_back(vhloName, Attr::get(vhloAttr.getContext(), data)); -+ stablehloAttrs.emplace_back( -+ vhloName, DenseI64ArrayAttr::get(vhloAttr.getContext(), data)); - return specialSuccess(); --} -- --SpecialResult convertDenseI64Array( -- StringAttr vhloName, Attribute vhloAttr, -- SmallVector& stablehloAttrs) { -- return convertDenseArray(vhloName, vhloAttr, -- stablehloAttrs); - } - - template diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index d996dea73a1be2..cf0ae34f4e54b3 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "9c8a1b7ddd5e4ef5abbdc7ef33b041a56343ae2f" - STABLEHLO_SHA256 = "4f6b6d4e5a96893a5ee6ff53aa017b7e151690f09c8c5eac77bd19b6330faa7a" + STABLEHLO_COMMIT = "20255865ba299ed67bdf6267478c8477aef7a60d" + STABLEHLO_SHA256 = "d703dba8c3f6ed1b5c7ac9772ec24d474644a5f10c0122aaca1401e0b4120471" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops.cc b/xla/mlir_hlo/mhlo/IR/hlo_ops.cc index a0f14f2710eccd..5c4ab30e7fe021 100644 --- a/xla/mlir_hlo/mhlo/IR/hlo_ops.cc +++ b/xla/mlir_hlo/mhlo/IR/hlo_ops.cc @@ -2089,15 +2089,8 @@ LogicalResult AllReduceOp::inferReturnTypeComponents( } // Populate inferred return shapes - for (auto resultType : adaptor.getOperands().getTypes()) { - auto rankedResult = resultType.dyn_cast(); - if (rankedResult) - inferredReturnShapes.emplace_back(rankedResult.getShape(), - rankedResult.getElementType(), - rankedResult.getEncoding()); - else - inferredReturnShapes.emplace_back(resultType.cast()); - } + return hlo::inferAllReduceOp(location, adaptor.getOperands(), + adaptor.getComputation(), inferredReturnShapes); return success(); } @@ -3266,157 +3259,24 @@ OpFoldResult CopyOp::fold(FoldAdaptor) { return getOperand(); } // ReduceWindowOp //===----------------------------------------------------------------------===// -namespace { - -// TODO(@sdasgup): Reuse the same function from hlo namespace. -FailureOr> convert1DAttribute( - std::optional optionalAttr, - std::optional loc, StringRef attrName) { - if (!optionalAttr.has_value()) return SmallVector{}; - - DenseIntElementsAttr attr = *optionalAttr; - auto attrType = attr.getType().cast(); - if (attrType.getRank() != 1) - return emitOptionalError(loc, "expects the shape of ", attrName, - " attribute to be 1-D, but got {", - attrType.getShape(), "}."); - auto values = attr.getValues(); - return SmallVector{values.begin(), values.end()}; -} - -LogicalResult verifyReduceWindowOpInputsAndInferWindow( - std::optional location, SmallVector inputTypes, - DenseIntElementsAttr windowDimensions, - std::optional windowStrides, - std::optional baseDilations, - std::optional windowDilations, - std::optional padding, - SmallVector& windowDims, - SmallVector& inferredWindow) { - // reduce_window_c1 - if (inputTypes.empty()) - return emitOptionalError(location, "requires at least 1 input value"); - - // Check for unranked tensors in input operands. - uint64_t numInputs = inputTypes.size(); - int64_t rankedInputIdx = -1; - for (uint64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) { - if (inputTypes[inputIdx].hasRank()) { - rankedInputIdx = inputIdx; - break; - } - } - bool allInputsUnranked = (rankedInputIdx == -1); - - // reduce_window_c2 - if (!allInputsUnranked) { - for (uint64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) - if (failed(mlir::verifyCompatibleShape(inputTypes[rankedInputIdx], - inputTypes[inputIdx]))) - return emitOptionalError( - location, "expects all inputs to have compatible shapes. Shape at", - " input-index ", inputIdx, - " is not compatible with shape at input-index ", rankedInputIdx); - } - - // reduce_window_i3 - auto windowDimsOrErr = - convert1DAttribute(windowDimensions, location, "window_dimensions"); - if (failed(windowDimsOrErr)) return failure(); - // reduce_window_i4 - auto windowStridesOrErr = - convert1DAttribute(windowStrides, location, "window_strides"); - if (failed(windowStridesOrErr)) return failure(); - // reduce_window_i5 - auto baseDilationsOrErr = - convert1DAttribute(baseDilations, location, "base_dilations"); - if (failed(baseDilationsOrErr)) return failure(); - // reduce_window_i6 - auto windowDilationsOrErr = - convert1DAttribute(windowDilations, location, "window_dilations"); - if (failed(windowDilationsOrErr)) return failure(); - // reduce_window_c12, reduce_window_i7 - auto paddingOrErr = hlo::convertPaddingAttribute(padding, location); - if (failed(paddingOrErr)) return failure(); - - // reduce_window_c4 - for (const auto inputType : inputTypes) { - if (!inputType.hasRank()) continue; - if (inputType.getRank() != static_cast((*windowDimsOrErr).size())) - return emitOptionalError( - location, "expects window-dimensions size == input rank, but got ", - "window-dimensions size: ", (*windowDimsOrErr).size(), - " and input: ", inputType, " with rank = ", inputType.getRank(), "."); - } - - // reduce_window_c5...reduce_window_c12 - auto windowOrErr = hlo::verifyWindowAttributesAndInferWindowDimensions( - *windowDimsOrErr, *windowStridesOrErr, *paddingOrErr, - /*lhsDilation=*/*baseDilationsOrErr, - /*rhsDilation=*/*windowDilationsOrErr, /*windowReversal=*/std::nullopt, - location); - if (failed(windowOrErr)) return failure(); - - windowDims.append(*windowDimsOrErr); - inferredWindow.append(*windowOrErr); - return success(); -} - -LogicalResult inferReduceWindowOp( - std::optional location, ValueRange inputs, - DenseIntElementsAttr windowDimensions, - std::optional windowStrides, - std::optional baseDilations, - std::optional windowDilations, - std::optional padding, - SmallVectorImpl& inferredReturnShapes) { - SmallVector inputTypes{llvm::map_range( - inputs.getTypes(), [](Type t) { return t.cast(); })}; - - SmallVector windowDims; - SmallVector inferredWindow; - // reduce_window_c1, reduce_window_c2, reduce_window_c4...reduce_window_c12, - // reduce_window_i4...reduce_window_i7 - if (failed(verifyReduceWindowOpInputsAndInferWindow( - location, inputTypes, windowDimensions, windowStrides, baseDilations, - windowDilations, padding, windowDims, inferredWindow))) - return failure(); - - // reduce_window_c1, reduce_window_c14...reduce_window_c16 - for (size_t i = 0; i < inputTypes.size(); ++i) { - auto inputRankedType = inputs[i].getType().dyn_cast(); - if (!inputRankedType) { - inferredReturnShapes.emplace_back(inputTypes[i].getElementType()); - } else { - auto resultShape = - inferWindowOutputShape(inputTypes[i].getShape(), inferredWindow); - auto inputBounds = hlo::encodingToBounds(inputRankedType.getEncoding()); - if (inputBounds.empty()) { - inferredReturnShapes.emplace_back(resultShape, - inputTypes[i].getElementType()); - } else { - auto resultBounds = inferWindowOutputShape(inputBounds, inferredWindow); - inferredReturnShapes.emplace_back( - resultShape, inputTypes[i].getElementType(), - hlo::boundsToEncoding(inputRankedType.getEncoding(), resultBounds)); - } - } - } - - return success(); -} - -} // namespace - LogicalResult ReduceWindowOp::inferReturnTypeComponents( MLIRContext*, std::optional location, ValueShapeRange operands, DictionaryAttr attributes, OpaqueProperties, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { ReduceWindowOp::Adaptor adaptor(operands, attributes, {}, regions); - return inferReduceWindowOp( - location, adaptor.getInputs(), adaptor.getWindowDimensions(), - adaptor.getWindowStrides(), adaptor.getBaseDilations(), - adaptor.getWindowDilations(), adaptor.getPadding(), inferredReturnShapes); + return hlo::inferReduceWindowOp( + location, adaptor.getInputs(), adaptor.getInitValues(), + llvm::to_vector(adaptor.getWindowDimensions().getValues()), + adaptor.getWindowStrides() + ? llvm::to_vector(adaptor.getWindowStrides()->getValues()) + : ArrayRef{}, + adaptor.getBaseDilations() + ? llvm::to_vector(adaptor.getBaseDilations()->getValues()) + : ArrayRef{}, + adaptor.getWindowDilations() + ? llvm::to_vector(adaptor.getWindowDilations()->getValues()) + : ArrayRef{}, + adaptor.getPadding(), adaptor.getBody(), inferredReturnShapes); } LogicalResult ReduceWindowOp::verify() { @@ -4043,112 +3903,60 @@ ParseResult ReduceOp::parse(OpAsmParser& parser, OperationState& result) { return success(); } -namespace { - -// TODO(@sdasgup): Reuse the same functions from hlo namespace. -LogicalResult verifyReduceOpInputsAndInferShape( - std::optional location, SmallVector inputTypes, - DenseIntElementsAttr dimensions, SmallVector& newDimensions, - Attribute& encoding) { - // reduce_i3 - if (dimensions.getType().getRank() != 1) - return emitOptionalError(location, "dimensions must be rank 1"); - - // Check for unranked tensors in input operands. - uint64_t numInputs = inputTypes.size(); - int64_t rankedInputIdx = -1; - for (uint64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) { - if (inputTypes[inputIdx].hasRank()) { - rankedInputIdx = inputIdx; - break; - } - } - bool allInputsUnranked = (rankedInputIdx == -1); - // reduce_c1 - if (!allInputsUnranked) { - for (uint64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) - if (failed(mlir::verifyCompatibleShape(inputTypes[rankedInputIdx], - inputTypes[inputIdx]))) - return emitOptionalError( - location, "expects all inputs to have compatible shapes. Shape at", - " input-index ", inputIdx, - " is not compatible with shape at input-index ", rankedInputIdx); - } - - DenseSet dimensionsToReduceSet; - for (int64_t dimension : dimensions.getValues()) { - // reduce_c4 - if ((!allInputsUnranked && - dimension >= inputTypes[rankedInputIdx].getRank()) || - dimension < 0) - return emitOptionalError( - location, "Out-of-bounds dimension ", dimension, ", expected to be ", - allInputsUnranked - ? "> 0" - : "less than the input-tensor rank " + - std::to_string(inputTypes[rankedInputIdx].getRank())); - - // reduce_c5 - if (!dimensionsToReduceSet.insert(dimension).second) - return emitOptionalError(location, - "Duplicate reduction dimension: ", dimension); - } - - if (!allInputsUnranked) { - auto rankedInput = inputTypes[rankedInputIdx].cast(); - ArrayRef inputBounds = - hlo::encodingToBounds(rankedInput.getEncoding()); - SmallVector newBounds; - for (int inputIdx = 0; inputIdx < rankedInput.getRank(); ++inputIdx) { - if (!dimensionsToReduceSet.count(inputIdx)) { - newDimensions.push_back(rankedInput.getDimSize(inputIdx)); - if (!inputBounds.empty()) newBounds.push_back(inputBounds[inputIdx]); - } - } - - // Set encoding based on the bounds only if the bounds is not empty. - encoding = nullptr; - if (!newBounds.empty()) - encoding = hlo::boundsToEncoding(rankedInput.getEncoding(), newBounds); - } - return success(); +LogicalResult ReduceOp::inferReturnTypeComponents( + MLIRContext*, std::optional location, ValueShapeRange operands, + DictionaryAttr attributes, OpaqueProperties, RegionRange regions, + SmallVectorImpl& inferredReturnShapes) { + ReduceOp::Adaptor adaptor(operands, attributes, {}, regions); + return hlo::inferReduceOp( + location, adaptor.getInputs().getTypes(), + llvm::to_vector(adaptor.getDimensions().getValues()), + adaptor.getBody(), inferredReturnShapes); } -LogicalResult inferReduceOp( - std::optional location, TypeRange inputTypes, - DenseIntElementsAttr dimensions, - SmallVectorImpl& inferredReturnShapes) { - SmallVector inputArgTensorTypes{ - llvm::map_range(inputTypes, [](Type t) { return t.cast(); })}; +void ReduceOp::build(OpBuilder&, OperationState& odsState, ValueRange inputs, + ValueRange initValues, DenseIntElementsAttr dimensions, + TypeRange elementTypes) { + odsState.addOperands(inputs); + odsState.addOperands(initValues); + odsState.addAttribute(getDimensionsAttrName(odsState.name), dimensions); + (void)odsState.addRegion(); SmallVector newDimensions; Attribute encoding; - // reduce_c1, reduce_c4, reduce_c5, reduce_i3 - if (failed(verifyReduceOpInputsAndInferShape( - location, inputArgTensorTypes, dimensions, newDimensions, encoding))) - return failure(); - // reduce_c2, reduce_c3, reduce_c7 - for (uint64_t inputIdx = 0; inputIdx < inputTypes.size(); ++inputIdx) { - ShapedType inputType = inputArgTensorTypes[inputIdx]; - Type elementType = inputType.getElementType(); - if (inputType.hasRank()) - inferredReturnShapes.emplace_back(newDimensions, elementType, encoding); - else - inferredReturnShapes.emplace_back(elementType); - } - - return success(); -} - -} // namespace + ReduceOp::Adaptor adaptor( + odsState.operands, + odsState.attributes.getDictionary(odsState.getContext()), {}, + odsState.regions); -LogicalResult ReduceOp::inferReturnTypeComponents( - MLIRContext*, std::optional location, ValueShapeRange operands, - DictionaryAttr attributes, OpaqueProperties, RegionRange regions, - SmallVectorImpl& inferredReturnShapes) { - ReduceOp::Adaptor adaptor(operands, attributes, {}, regions); - return inferReduceOp(location, adaptor.getInputs().getTypes(), - adaptor.getDimensions(), inferredReturnShapes); + SmallVector inputArgTensorTypes{ + llvm::map_range(adaptor.getInputs().getTypes(), + [](Type t) { return t.cast(); })}; + SmallVector initValueTensorTypes{ + llvm::map_range(adaptor.getInitValues().getTypes(), + [](Type t) { return t.cast(); })}; + + if (succeeded(hlo::verifyReduceOpInputsAndInferShape( + odsState.location, inputArgTensorTypes, + llvm::to_vector(dimensions.getValues()), newDimensions, + encoding))) { + SmallVector inferredReturnTypes; + for (uint64_t inputIdx = 0; inputIdx < inputArgTensorTypes.size(); + ++inputIdx) { + Type elementTy = elementTypes[inputIdx]; + ShapedType inputType = inputArgTensorTypes[inputIdx]; + if (inputType.hasRank()) { + inferredReturnTypes.push_back( + RankedTensorType::get(newDimensions, elementTy, encoding)); + } else { + assert(encoding == nullptr && "attribute not supported"); + inferredReturnTypes.push_back(UnrankedTensorType::get(elementTy)); + } + } + odsState.addTypes(inferredReturnTypes); + } else { + llvm::report_fatal_error("Failed to infer result type(s)."); + } } LogicalResult ReduceOp::verify() { @@ -6146,24 +5954,14 @@ OpFoldResult CompareOp::fold(FoldAdaptor adaptor) { // SelectAndScatterOp //===----------------------------------------------------------------------===// -namespace { - -// TODO(@sdasgup): Reuse the same function from hlo namespace. -LogicalResult inferSelectAndScatterOp( - Value operand, SmallVectorImpl& inferredReturnTypes) { - // select_and_scatter_c11 - inferredReturnTypes.push_back(operand.getType()); - return success(); -} - -} // namespace - LogicalResult SelectAndScatterOp::inferReturnTypes( - MLIRContext*, std::optional, ValueRange operands, + MLIRContext*, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { SelectAndScatterOp::Adaptor adaptor(operands, attributes, {}, regions); - return inferSelectAndScatterOp(adaptor.getOperand(), inferredReturnTypes); + return hlo::inferSelectAndScatterOp(location, adaptor.getOperand(), + adaptor.getScatter(), + inferredReturnTypes); } LogicalResult SelectAndScatterOp::verify() { @@ -6190,23 +5988,14 @@ LogicalResult SelectAndScatterOp::verify() { // ScatterOp //===----------------------------------------------------------------------===// -namespace { - -// TODO(@sdasgup): Reuse the same function from hlo namespace. -LogicalResult inferScatterOp(std::optional, ValueRange inputs, - SmallVectorImpl& inferredReturnTypes) { - llvm::append_range(inferredReturnTypes, inputs.getTypes()); - return success(); -} - -} // namespace - LogicalResult ScatterOp::inferReturnTypes( MLIRContext*, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { ScatterOp::Adaptor adaptor(operands, attributes, {}, regions); - return inferScatterOp(location, adaptor.getInputs(), inferredReturnTypes); + return hlo::inferScatterOp(location, adaptor.getInputs(), + adaptor.getUpdateComputation(), + inferredReturnTypes); } LogicalResult ScatterOp::verify() { diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops.td b/xla/mlir_hlo/mhlo/IR/hlo_ops.td index 6ded5e0b34d970..307b4652559aba 100644 --- a/xla/mlir_hlo/mhlo/IR/hlo_ops.td +++ b/xla/mlir_hlo/mhlo/IR/hlo_ops.td @@ -1490,7 +1490,6 @@ def MHLO_AllGatherOp : MHLO_Op<"all_gather", [SameOperandsAndResultElementType]> } def MHLO_AllReduceOp : MHLO_Op<"all_reduce", [ - SameOperandsAndResultElementType, SingleBlockImplicitTerminator<"ReturnOp">, InferTensorType ]> { @@ -1544,8 +1543,7 @@ def MHLO_AllReduceOp : MHLO_Op<"all_reduce", [ let hasCustomHLOConverter = 1; } -def MHLO_ReduceScatterOp : MHLO_Op<"reduce_scatter", - [SameOperandsAndResultElementType]> { +def MHLO_ReduceScatterOp : MHLO_Op<"reduce_scatter", []> { let summary = "ReduceScatter operation"; let description = [{ Within each process group in the process grid, performs reduction, using @@ -1691,6 +1689,12 @@ def MHLO_ReduceOp: MHLO_ShapedInterfaceOp<"reduce", [ // compatible with reduce op's operands. let regions = (region SizedRegion<1>:$body); + // Builder + let builders = [ + OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$init_values, + "DenseIntElementsAttr":$dimensions, "TypeRange":$element_types)>, + ]; + // TODO(b/129422361): ReduceOp has special conversion logic to HLO. let hasCustomHLOConverter = 1; } diff --git a/xla/mlir_hlo/mhlo/transforms/group_reduction_dimensions/group_reduction_dimensions.cc b/xla/mlir_hlo/mhlo/transforms/group_reduction_dimensions/group_reduction_dimensions.cc index 61522d609910e5..d2058c0e23254f 100644 --- a/xla/mlir_hlo/mhlo/transforms/group_reduction_dimensions/group_reduction_dimensions.cc +++ b/xla/mlir_hlo/mhlo/transforms/group_reduction_dimensions/group_reduction_dimensions.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Types.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -224,8 +225,11 @@ LogicalResult tryLowerTo1DOr2DReduction( int64_t reductionDim = leadingReduction ? 0 : 1; auto reductionDimAttr = rewriter.getI64VectorAttr({reductionDim}); Value initVal = op.getInitValues().front(); - auto reductionOp = - rewriter.create(loc, intermResult, initVal, reductionDimAttr); + SmallVector elementTypes{llvm::map_range( + op.getBody().front().getTerminator()->getOperands(), + [](Value v) { return v.getType().cast().getElementType(); })}; + auto reductionOp = rewriter.create(loc, intermResult, initVal, + reductionDimAttr, elementTypes); rewriter.inlineRegionBefore(op.getBody(), reductionOp.getBody(), reductionOp.getBody().begin()); intermResult = reductionOp->getResults().front(); diff --git a/xla/mlir_hlo/mhlo/transforms/mhlo_canonicalize_reduction/mhlo_canonicalize_reduction.cc b/xla/mlir_hlo/mhlo/transforms/mhlo_canonicalize_reduction/mhlo_canonicalize_reduction.cc index c71ec00eccde50..bb68ec22fbac90 100644 --- a/xla/mlir_hlo/mhlo/transforms/mhlo_canonicalize_reduction/mhlo_canonicalize_reduction.cc +++ b/xla/mlir_hlo/mhlo/transforms/mhlo_canonicalize_reduction/mhlo_canonicalize_reduction.cc @@ -18,6 +18,7 @@ limitations under the License. #include +#include "llvm/ADT/STLExtras.h" #include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -219,8 +220,12 @@ struct HloCanonicalizeReductionPass elemTy), operand, newOperandShape)); } - auto newOp = - b.create(loc, newOperands, op.getInitValues(), attr); + SmallVector elementTypes{llvm::map_range( + op.getBody().front().getTerminator()->getOperands(), [](Value v) { + return v.getType().cast().getElementType(); + })}; + auto newOp = b.create(loc, newOperands, op.getInitValues(), + attr, elementTypes); newOp.getBody().takeBody(op.getBody()); SmallVector newResults; diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir index 5db74efa72f529..fd2e0bb9dc3225 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir @@ -53,6 +53,41 @@ func.func @all_reduce_tuple(%arg0: tensor<10xf32>, %arg1: tensor) -> tensor // ----- +// CHECK-LABEL: func @all_reduce_with_promotable_types +func.func @all_reduce_with_promotable_types(%operand: tensor) -> tensor { + + %result = "mhlo.all_reduce"(%operand) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + "mhlo.return"(%0) : (tensor) -> () + }) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_handle = #mhlo.channel_handle + } : (tensor) -> tensor + + func.return %result : tensor +} + +// ----- + +// CHECK-LABEL: func @all_reduce_with_promotable_quantized_types +func.func @all_reduce_with_promotable_quantized_types(%operand: tensor>) + -> tensor> { + + %result = "mhlo.all_reduce"(%operand) ({ + ^bb0(%arg0: tensor>, %arg1: tensor>): + %0 = mhlo.add %arg0, %arg1 : tensor> + "mhlo.return"(%0) : (tensor>) -> () + }) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_handle = #mhlo.channel_handle + } : (tensor>) -> tensor> + + func.return %result : tensor> +} + +// ----- + func.func @all_reduce_invalid_reducer(%operand: tensor<10xf32>) -> tensor<10xf32> { // expected-error@+2 {{'mhlo.all_reduce' op failed to infer returned types}} // expected-error@+1 {{Reduction-region must take 2 parameters, but takes 3 parameter(s)}} @@ -200,7 +235,8 @@ func.func @all_reduce_invalid_return_type(%operand: tensor<10xf32>) -> tensor<10 // ----- func.func @all_reduce_invalid_return_type(%operand: tensor<10xf32>) -> tensor<10xi32> { - // expected-error@+1 {{'mhlo.all_reduce' op requires the same element type for all operands and results}} + // expected-error@+2 {{'mhlo.all_reduce' op inferred type(s) 'tensor<10xf32>' are incompatible with return type(s) of operation 'tensor<10xi32>'}} + // expected-error@+1 {{'mhlo.all_reduce' op failed to infer returned types}} %0 = "mhlo.all_reduce"(%operand) ({ ^bb0(%arg0: tensor, %arg1: tensor): %max = mhlo.maximum %arg0, %arg1 : tensor @@ -309,6 +345,38 @@ func.func @reduce_scatter_dynamic(%data: tensor) -> tensor { // ----- +// CHECK-LABEL: func @reduce_scatter_with_promotable_types +func.func @reduce_scatter_with_promotable_types(%data: tensor<4x16xf32>) -> tensor<4x4xf64> { + %0 = "mhlo.reduce_scatter"(%data) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = mhlo.add %arg2, %arg3 : tensor + "mhlo.return"(%1) : (tensor) -> () + }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64, + channel_handle = #mhlo.channel_handle, + use_global_device_ids} : (tensor<4x16xf32>) -> tensor<4x4xf64> + func.return %0 : tensor<4x4xf64> +} + +// ----- + +// CHECK-LABEL: func @reduce_scatter_with_promotable_quantized_types +func.func @reduce_scatter_with_promotable_quantized_types( + %data: tensor<4x16x!quant.uniform>) -> + tensor<4x4x!quant.uniform> { + %0 = "mhlo.reduce_scatter"(%data) ({ + ^bb0(%arg2: tensor>, %arg3: tensor>): + %1 = mhlo.add %arg2, %arg3 : tensor> + "mhlo.return"(%1) : (tensor>) -> () + }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64, + channel_handle = #mhlo.channel_handle, + use_global_device_ids} : (tensor<4x16x!quant.uniform>) -> tensor<4x4x!quant.uniform> + func.return %0 : tensor<4x4x!quant.uniform> +} + +// ----- + func.func @reduce_scatter_c2(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { // expected-error@+1 {{expects scatter_dimension >= 0}} %0 = "mhlo.reduce_scatter"(%data) ({ diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/verifier_reduce_op.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/verifier_reduce_op.mlir index 5087451dcb3272..baab3c2bb83678 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/verifier_reduce_op.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/verifier_reduce_op.mlir @@ -94,6 +94,35 @@ func.func @reduce_mix_rank_and_unranked(%arg0: tensor<4x4xf32>, %arg1: tensor<*x func.return %0#0, %0#1 : tensor<4xf32>, tensor<*xf32> } +// ----- + +// CHECK-LABEL: func @reduce_with_promotable_types +func.func @reduce_with_promotable_types(%arg0: tensor<4x4xf32>, %arg1 : tensor) + -> (tensor<4xf64>) { + %0 = "mhlo.reduce"(%arg0, %arg1) ({ + + ^bb0(%arg2: tensor, %arg3: tensor ): + %1 = "mhlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () + + }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<4x4xf32>, tensor) -> tensor<4xf64> + + func.return %0: tensor<4xf64> +} + +// ----- + +// CHECK-LABEL: func @reduce_with_promotable_quantized_types +func.func @reduce_with_promotable_quantized_types(%arg0: tensor<4x4x!quant.uniform>, + %arg1: tensor>) -> tensor<4x!quant.uniform> { + %0 = mhlo.reduce(%arg0 init: %arg1) across dimensions = [0] : (tensor<4x4x!quant.uniform>, tensor>) -> tensor<4x!quant.uniform> + reducer(%arg2: tensor>, %arg3: tensor>) { + %1 = mhlo.add %arg2, %arg3 : tensor> + mhlo.return %1 : tensor> + } + return %0 : tensor<4x!quant.uniform> +} + // Next, we have the invalid testcases. // ----- diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/verifier_reduce_window_op.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/verifier_reduce_window_op.mlir index 91d5a3866ccfe1..cc2e54bba84715 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/verifier_reduce_window_op.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/verifier_reduce_window_op.mlir @@ -79,6 +79,46 @@ func.func @reduce_window_with_non_scalar_block_arg2(%arg0: tensor<4x2xf32>, // ----- +// CHECK-LABEL: func @reduce_window_with_promotable_types +func.func @reduce_window_with_promotable_types(%arg0: tensor<4x2xf32>, + %arg1: tensor<4x2xf32>, %init0: tensor, %init1: tensor) -> + (tensor<2x2xf64>, tensor<2x2xf32>) { + %0:2 = "mhlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({ + ^bb0(%a0: tensor, %a1: tensor, %b0: tensor, + %b1: tensor): + %2 = mhlo.add %a0, %b0 : tensor + %3 = mhlo.add %a1, %b1 : tensor + "mhlo.return"(%2,%3) : (tensor, tensor) -> () + }) + { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>, + window_dimensions = dense<[5, 1]> : tensor<2xi64>, + window_strides = dense<[3, 1]> : tensor<2xi64> } + : (tensor<4x2xf32>, tensor<4x2xf32>, tensor, tensor) -> + (tensor<2x2xf64>, tensor<2x2xf32>) + func.return %0#0, %0#1 : tensor<2x2xf64>, tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @reduce_window_with_promotable_quantized_types +func.func @reduce_window_with_promotable_quantized_types(%arg0: tensor<4x2x!quant.uniform>, + %init0: tensor>) -> (tensor<2x2x!quant.uniform>) { + + %0 = "mhlo.reduce_window"(%arg0, %init0) ({ + ^bb0(%a0: tensor>, %b0: tensor>): + %1 = mhlo.add %a0, %b0 : tensor> + "mhlo.return"(%1) : (tensor>) -> () + }) + { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>, + window_dimensions = dense<[5, 1]> : tensor<2xi64>, + window_strides = dense<[3, 1]> : tensor<2xi64> + } + : (tensor<4x2x!quant.uniform>, tensor>) -> (tensor<2x2x!quant.uniform>) + func.return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + func.func @reduce_window_invalid_inputs(%arg0: tensor<4x2xf32>, %arg1: tensor<4x3xi32>, %init0: tensor, %init1: tensor) -> (tensor<2x2xf32>, tensor<2x2xi32>) { diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/verifier_scatter_op.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/verifier_scatter_op.mlir index fc484c591319e3..a2b4a488b2b305 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/verifier_scatter_op.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/verifier_scatter_op.mlir @@ -46,6 +46,54 @@ func.func @scatter_with_unranked_inputs(%input_tensor: tensor<*xf32>, // ----- +// CHECK: func @scatter_with_promotable_types +func.func @scatter_with_promotable_types(%input_tensor: tensor<200x100x300xf32>, + %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> + tensor<200x100x300xf64> { + %0 = "mhlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ + ^bb0(%lhs: tensor, %rhs: tensor): + %add = mhlo.add %lhs, %rhs : tensor + "mhlo.return"(%add) : (tensor) -> () + }) { + scatter_dimension_numbers = #mhlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + >, + indices_are_sorted = true, + unique_indices = true + } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> + tensor<200x100x300xf64> + func.return %0 : tensor<200x100x300xf64> +} + +// ----- + +// CHECK: func @scatter_with_promotable_quantized_types +func.func @scatter_with_promotable_quantized_types(%input_tensor: tensor<200x100x300x!quant.uniform>, + %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300x!quant.uniform>) -> + tensor<200x100x300x!quant.uniform> { + %0 = "mhlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ + ^bb0(%lhs: tensor>, %rhs: tensor>): + %add = mhlo.add %lhs, %rhs : tensor> + "mhlo.return"(%add) : (tensor>) -> () + }) { + scatter_dimension_numbers = #mhlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + >, + indices_are_sorted = true, + unique_indices = true + } : (tensor<200x100x300x!quant.uniform>, tensor<10x2xi32>, + tensor<10x300x!quant.uniform>) -> + tensor<200x100x300x!quant.uniform> + func.return %0 : tensor<200x100x300x!quant.uniform> +} +// ----- + func.func @invalid_scatter(%input_tensor: tensor<200x100x300xf32>, %scatter_indices: tensor<10x2xf32>, %updates: tensor<10x300xf32>) -> tensor<200x100x300xf32> { diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/verifier_select_and_scatter_op.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/verifier_select_and_scatter_op.mlir index acc5cb804df900..c608b30f540d68 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/verifier_select_and_scatter_op.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/verifier_select_and_scatter_op.mlir @@ -26,6 +26,61 @@ func.func @select_and_scatter( func.return %1 : tensor<10x24x24x64xf32> } + +// CHECK: func @select_and_scatter_with_promotable_types +func.func @select_and_scatter_with_promotable_types( + %arg0: tensor<10x24x24x64xf32>, + %arg1: tensor<10x12x12x64xf32>) -> () { + %0 = mhlo.constant dense<0.000000e+00> : tensor + %1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %2 = "mhlo.compare"(%arg3, %arg4) { + comparison_direction = #mhlo + } : (tensor, tensor) -> tensor + "mhlo.return"(%2) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %2 = mhlo.add %arg3, %arg4 : tensor + "mhlo.return"(%2) : (tensor) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>, + padding = dense<0> : tensor<4x2xi64> + } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> + tensor<10x24x24x64xf64> + func.return +} + + +// CHECK: func @select_and_scatter_with_promotable_quantized_types +func.func @select_and_scatter_with_promotable_quantized_types( + %arg0: tensor<10x24x24x64x!quant.uniform>, + %arg1: tensor<10x12x12x64x!quant.uniform>, + %arg2 : tensor>) -> + tensor<10x24x24x64x!quant.uniform> { + + %1 = "mhlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor>, %arg4: tensor>): + %2 = "mhlo.compare"(%arg3, %arg4) { + compare_type = #mhlo, + comparison_direction = #mhlo + } : (tensor>, tensor>) -> tensor + "mhlo.return"(%2) : (tensor) -> () + }, { + ^bb0(%arg3: tensor>, %arg4: tensor>): + %2 = mhlo.add %arg3, %arg4 : tensor> + "mhlo.return"(%2) : (tensor>) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64> + } : (tensor<10x24x24x64x!quant.uniform>, + tensor<10x12x12x64x!quant.uniform>, + tensor>) -> + tensor<10x24x24x64x!quant.uniform> + func.return %1 : tensor<10x24x24x64x!quant.uniform> +} + +// CHECK: func @select_and_scatter_with_unranked_dims func.func @select_and_scatter_with_unranked_dims( %arg0: tensor<4x5x1x1xbf16>, %arg1: tensor<2x2x1x1xbf16>,