diff --git a/include/torch-mlir/Conversion/CMakeLists.txt b/include/torch-mlir/Conversion/CMakeLists.txt index d6552314999b..c2e757f7a0ff 100644 --- a/include/torch-mlir/Conversion/CMakeLists.txt +++ b/include/torch-mlir/Conversion/CMakeLists.txt @@ -1,3 +1,5 @@ +add_subdirectory(TorchOnnxToTorch) + set(LLVM_TARGET_DEFINITIONS Passes.td) if(TORCH_MLIR_ENABLE_STABLEHLO) mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO) diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/CMakeLists.txt b/include/torch-mlir/Conversion/TorchOnnxToTorch/CMakeLists.txt new file mode 100644 index 000000000000..a58ce5bf9b7d --- /dev/null +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls) +add_public_tablegen_target(TorchMLIRConversionTorchOnnxToTorchPassIncGen) +add_mlir_doc(Passes TorchMLIRConversionTorchOnnxToTorchPasses ./ -gen-pass-doc) diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.h new file mode 100644 index 000000000000..6eea35c9d255 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.h @@ -0,0 +1,27 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_H +#define TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include + +namespace mlir::torch::onnx_c { + +std::unique_ptr> createTorchOnnxToTorchPass(); + +/// Registers all torch-mlir conversion passes. +void registerTorchOnnxToTorchPasses(); + +} // namespace mlir::torch::onnx_c + +#endif // TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_H diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.td b/include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.td new file mode 100644 index 000000000000..b92649d025a6 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.td @@ -0,0 +1,26 @@ +//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSES +#define TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertTorchOnnxToTorch : Pass<"convert-torch-onnx-to-torch", "func::FuncOp"> { + let summary = "Converts ONNX custom ops in the torch dialect to native torch ops"; + let description = [{ + Converts equivalent ONNX custom ops to built-in equivalents. + + See the README for a detailed description of how this operates. + }]; + + let constructor = "mlir::torch::onnx_c::createTorchOnnxToTorchPass()"; +} + +#endif // TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSES diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h new file mode 100644 index 000000000000..4c8d73a48116 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -0,0 +1,169 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_TORCHONNX_CONVERSION_UTILS_H +#define TORCHMLIR_CONVERSION_TORCHONNX_CONVERSION_UTILS_H + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir::torch::onnx_c { + +/// Used during ONNX pattern matching to bind common patterns of operands, +/// result types and attributes to local variables in a way that is easy +/// to fail the pattern if constraints are violated. Most methods return +/// a ParseResult, which allows for chaining like: +/// +/// if (binder.tensorOperand(foo) || binder.tensorResultType(t)) +/// return failure(); +struct OpBinder { + OpBinder(Operation *op) : op(op) {} + + Location getLoc() { return op->getLoc(); } + + // Operand matches of different arities. + ParseResult tensorOperand(Value &value0) { + if (op->getNumOperands() != 1) + return failure(); + value0 = op->getOperand(0); + if (!toValidTensorType(value0.getType())) + return failure(); + return success(); + } + + ParseResult tensorOperands(Value &value0, Value &value1) { + if (op->getNumOperands() != 2) + return failure(); + value0 = op->getOperand(0); + value1 = op->getOperand(1); + if (!toValidTensorType(value0.getType()) || + !toValidTensorType(value1.getType())) + return failure(); + return success(); + } + + // Result type matchers of different arities. + ParseResult tensorResultType(Torch::ValueTensorType &type0) { + if (op->getNumResults() != 1) + return failure(); + auto t = toValidTensorType(op->getResult(0).getType()); + if (!t) + return failure(); + type0 = t; + return success(); + } + + // Attribute accessors. + ParseResult s64BoolAttr(bool &value, StringRef nameSuffix, + bool defaultValue = false) { + SmallString<64> name("torch.onnx."); + name.append(nameSuffix); + auto attr = op->getAttr(name); + if (!attr) { + value = defaultValue; + return success(); + } + if (auto integerAttr = dyn_cast(attr)) { + IntegerType t = cast(integerAttr.getType()); + if (!t.isSigned() || t.getWidth() != 64) + return failure(); + value = static_cast(integerAttr.getSInt()); + return success(); + } + return failure(); + } + + ParseResult s64IntegerAttr(int64_t &value, StringRef nameSuffix, + int64_t defaultValue = 0) { + SmallString<64> name("torch.onnx."); + name.append(nameSuffix); + auto attr = op->getAttr(name); + if (!attr) { + value = defaultValue; + return success(); + } + if (auto integerAttr = dyn_cast(attr)) { + IntegerType t = cast(integerAttr.getType()); + if (!t.isSigned() || t.getWidth() != 64) + return failure(); + value = integerAttr.getSInt(); + return success(); + } + return failure(); + } + + Torch::ValueTensorType toValidTensorType(Type t) { + auto tt = dyn_cast(t); + if (tt && tt.hasSizes()) + return tt; + return {}; + } + + Operation *op; +}; + +/// We use a single pattern per ONNX domain to handle all named custom +/// ops. +/// This allows us to avoid the n^2 problem on pattern application by +/// implementing a secondary index based on the name and sinceVersion +/// attributes. +/// It also lets us add some ergonomics for trivial cases. +class OnnxCustomOpConversionPattern + : public OpConversionPattern { +public: + using HandlerFn = LogicalResult (*)(OpBinder binder, + ConversionPatternRewriter &rewriter); + struct HandlerReg { + HandlerReg(HandlerFn callback, int64_t sinceVersion) + : callback(callback), sinceVersion(sinceVersion) {} + HandlerFn callback; + int64_t sinceVersion; + }; + + OnnxCustomOpConversionPattern(MLIRContext *context, std::string domainPrefix, + int64_t domainVersion) + : OpConversionPattern(context), domainPrefix(std::move(domainPrefix)), + domainVersion(domainVersion) {} + + LogicalResult + matchAndRewrite(Torch::OperatorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; + + /// Adds all fully qualified operator names to the given set. + /// This is typically used for implementing a dynamic legality + /// check for torch.operator names. + void populateLegalizedNames(DenseSet &legalizedNames); + + /// Register a conversion for a specific ONNX operator. For the + /// default domain, this is the canonical ONNX operator name (i.e. + /// "Acos"). + /// Multiple conversions can be registered for the same op, most + /// commonly differing by their `sinceVersion`. + void onOp(StringRef name, int64_t sinceVersion, HandlerFn callback); + +private: + std::string domainPrefix; + int64_t domainVersion; + DenseMap> namedHandlers; +}; + +// Patterns are split into chunks to speed compile time and reduce some +// contention on the same source files. +void populateDefaultDomainAtoF(OnnxCustomOpConversionPattern &patterns); +void populateDefaultDomainGtoP(OnnxCustomOpConversionPattern &patterns); +void populateDefaultDomainQtoZ(OnnxCustomOpConversionPattern &patterns); + +} // namespace mlir::torch::onnx_c + +#endif // TORCHMLIR_CONVERSION_TORCHONNX_CONVERSION_UTILS_H diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/README.md b/include/torch-mlir/Conversion/TorchOnnxToTorch/README.md new file mode 100644 index 000000000000..6de1cc923411 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/README.md @@ -0,0 +1,133 @@ +# TorchOnnx To Torch Conversions + +We enable the direct representation of many ONNX features directly in +the `torch` dialect as `torch.operator` custom ops with names like +`onnx.{OperatorName}`. The majority of ONNX operators are represented +with a systematic transformation. See +[onnx_importer.py](https://github.com/nod-ai/SHARK-Turbine/blob/main/python/shark_turbine/importers/onnx_importer.py) +for the reference importer which complies with the rules below +(this is planned to be upstreamed to torch-mlir proper in the near +future). + +## Adding new ONNX operators + +With the exception of certain special or complicated ONNX operators, most +are relatively straight-forward to map, following this general procedure: + +* Plan the ops you wish to support by consulting the + [ONNX operator database](https://onnx.ai/onnx/operators/). + * This database has detailed diffs wrt different support versions but + at the level of detail we operate, most version diffs are inconsequential + and just require a bit more pattern support. + * This typically applies to generalization of broadcasting semantics, + expanded type support, and other things of the like. +* *Prerequisite*: Add support for the op to torch-mlir if it does not + already exist. +* Open the corresponding implementation file `DefaultDomainXtoY.cpp` + corresponding with the alphabetic sort of the op and add a conversion. +* Generate successful test cases: + * Either run the Turbine importer to produce MLIR output for all + ops/models in the ONNX test suite or use a dump that someone has + generated: + * [2023-Nov-21](https://drive.google.com/file/d/1P6QaRXGnCeApjdjNmykLxWa-yqMmIO-d/view?usp=sharing) + * There are often many variants of tests for checking conformance of + different historic ONNX encodings, but these are often not load bearing + at the MLIR level. + * Pick a handful of test cases and add them to + `test/Conversion/TorchOnnxToTorch/simple_ops_x_to_y.mlir` corresponding to an + alphabetic breakdown. At this time, ignore tests that are not exercising + useful differences in the pattern implementations. +* Generate failure test cases: + * Some ops have forms that do not (easily) map to torch-mlir. If you leave + an op under-implemented, add a failing test case to + `test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir`. +* Optional but recommended: Use your test case files to fuzz against the + torch-mlir backend of your choice by running a backend conversion pipeline + and fixing any crashes/issues. +* Send a patch with your changes. + +## ONNX proto to `torch` dialect mapping + +### Type Conversion + +* Tensors: ONNX tensor types are converted to `torch.vtensor` + with static and dynamic dimensions. We require that shape + inference has run to produce ranked tensors. +* Tensor element types are directly converted to corresponding + MLIR types as used by the rest of torch-mlir. +* String, sequence and sparse tensor types are presently not mapped. + +### Attributes + +A subset of attributes types are converted directly to an attribute +dict on the op with a name like `torch.onnx.{AttributeName}`. The +following attribute type mappings are made: + +* `FLOAT`: `FloatAttr` +* `INT`: Signed `IntegerAttr` of width 64 +* `STRING`: `StringAttr` +* `TENSOR`: Converted to one of: + * `DenseResourceElementsAttr` for inlined `raw_data` + * `DenseElementsAttr` for splats + * `DenseElementsAttr` for inlined typed proto initialization +* `FLOATS`: `ArrayAttr` of `FloatAttr` +* `INTS`: `ArrayAttr` of signed `IntegerAttr` of width 64 +* `STRINGS`: `ArrayAttr` of `StringAttr` +* `TENSORS`: `ArrayAttr` of corresponding `TENSOR` conversion + +The following attribute types have no present, systematic conversion. +Their presence on an op indicates that the op is a special form, which +must be handled specially: + +* `GRAPH` +* `SPARSE_TENSOR` (TBD: it is possible to handle this systematically if + useful). +* `TYPE_PROTO` (TBD: it may be possible to handle this systematically if + useful). +* Plural equivalents of the above. + +### Default operation conversion + +Operations are converted to a `torch.operator` with name `onnx.{OperatorName}`. +The constraint that the ONNX graph is topologically sorted and free of +cycles matches the SSA form. Operands and results are mapped directly. + +This conversion only applies to the default (empty) domain. + +### Quantization information + +Quantization parameters are carried out of line in the ONNX protobuf +and will be repatriated upon import to torch. The exact mechanism is +not yet implemented. + +### Version and metadata + +The `IsolatedFromAbove` parent of the ops can contain the following +metadata: + +* `torch.onnx_meta.ir_version`: 64bit `IntegerAttr` corresponding to + `ModelProto.ir_version`. +* `torch.onnx_meta.producer_name`: `StringAttr` corresponding to + `ModelProto.producer_name`. +* `torch.onnx_meta.producer_version`: `StringAttr` corresponding to + `ModelProto.producer_version`. +* `torch.onnx_meta.opset_version`: 64bit `IntegerAttr` corresponding + to `ModelProto.opset_import.version` for the domain "" (empty). + Will be ommitted if the default opset is not included. +* `torch.onnx_meta.opset_versions`: DictAttr of 64bit `IntegerAttr` + for each non default domain. + +Generally, the importer handles variations in `ir_version` whereas +the transformations here handle opset version differences. Version +independent transformations are encouraged where possible if there +are only minor variations of an op. Major variations should use +`since_version` sensitive patterns. + +### Special op forms + +Certain ONNX operators map to different structural components of +torch-mlir's representation: + +* `ConstantOfShape`: Mapped to `torch.vtensor.literal` with + a corresponding `value` attribute. + diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 8956066b8769..d9030c23a66f 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -14,16 +14,19 @@ set(LinkedLibs MLIRTosaDialect MLIRSupport - TorchMLIRTorchPasses - TorchMLIRTorchConversionDialect - + # Dialects. + TorchMLIRTMTensorDialect TorchMLIRTorchDialect - TorchMLIRTorchConversionPasses + TorchMLIRTorchConversionDialect + # Dialect passes. TorchMLIRTMTensorPasses - TorchMLIRTMTensorDialect + TorchMLIRTorchConversionPasses + TorchMLIRTorchPasses + # Conversion passes. TorchMLIRConversionPasses + TorchMLIRTorchOnnxToTorch ) if(TORCH_MLIR_ENABLE_REFBACKEND) diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index f26b4d6e895e..afbe775d3a20 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(TorchOnnxToTorch) add_subdirectory(TorchToLinalg) add_subdirectory(TorchToSCF) add_subdirectory(TorchToArith) diff --git a/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt b/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt new file mode 100644 index 000000000000..807db64eac64 --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_conversion_library(TorchMLIRTorchOnnxToTorch + DefaultDomainAtoF.cpp + DefaultDomainGtoP.cpp + DefaultDomainQtoZ.cpp + Passes.cpp + Patterns.cpp + TorchOnnxToTorch.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchOnnxToTorch + + DEPENDS + TorchMLIRConversionTorchOnnxToTorchPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + TorchMLIRTorchDialect +) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp new file mode 100644 index 000000000000..5bcf17a1fd92 --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -0,0 +1,146 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::onnx_c; + +// Simple rewrites for the default domain. +// See: https://onnx.ai/onnx/operators/ +// For operators that are effectively version invariant, we register with +// sinceVersion==1. We interpret this to include the following spec +// diffs that are irrelevant to this level of lowering: +// * Supported element types. +// * Limited broadcasting to full broadcasting support. +// +// There are a lot of spec revisions that basically generalized elementwise +// to be more normal and a direct translation vs a special case. This +// results in a lot of ONNX test cases that all reduce to the exact same +// thing here, so we simplify. +void mlir::torch::onnx_c::populateDefaultDomainAtoF( + OnnxCustomOpConversionPattern &patterns) { + patterns.onOp("Abs", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + // TODO: Acos unimplemented in torch-mlir + // TODO: Acosh unimplemented in torch-mlir + // Add became forward compatible with Torch in version 7. + patterns.onOp("Add", 7, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) + return failure(); + Value const1 = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs, const1); + return success(); + }); + // TODO: AffineGrid + patterns.onOp("And", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); + patterns.onOp( + "ArgMax", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + bool keepDims; + int64_t axis; + bool selectLastIndex; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType) || + binder.s64BoolAttr(keepDims, "keepdims", true) || + binder.s64IntegerAttr(axis, "axis", 0) || + binder.s64BoolAttr(selectLastIndex, "select_last_index", false)) + return failure(); + + if (selectLastIndex) { + // TODO: Figure out how to support this case. Need to add a reverse + // or something. + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: select_last_index=true"); + } + + // ONNX allows negative axis. + if (axis < 0) + axis += + cast(operand.getType()).getSizes().size(); + + Value constAxis = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis)); + Value constKeepDims = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getBoolAttr(keepDims)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, constAxis, constKeepDims); + return success(); + }); + patterns.onOp( + "ArgMin", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + bool keepDims; + int64_t axis; + bool selectLastIndex; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType) || + binder.s64BoolAttr(keepDims, "keepdims", true) || + binder.s64IntegerAttr(axis, "axis", 0) || + binder.s64BoolAttr(selectLastIndex, "select_last_index", false)) + return failure(); + + if (selectLastIndex) { + // TODO: Figure out how to support this case. Need to add a reverse + // or something. + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: select_last_index=true"); + } + + // ONNX allows negative axis. + if (axis < 0) + axis += + cast(operand.getType()).getSizes().size(); + + Value constAxis = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis)); + Value constKeepDims = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getBoolAttr(keepDims)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, constAxis, constKeepDims); + return success(); + }); + // TODO: Asin unimplemented in torch-mlir + // TODO: Asinh unimplemented in torch-mlir + // TODO: Atan unimplemented in torch-mlir + // TODO: Atanh unimplemented in torch-mlir +} diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp new file mode 100644 index 000000000000..af4f06fdef77 --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -0,0 +1,29 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::onnx_c; + +// Simple rewrites for the default domain. +// See: https://onnx.ai/onnx/operators/ +// For operators that are effectively version invariant, we register with +// sinceVersion==1. We interpret this to include the following spec +// diffs that are irrelevant to this level of lowering: +// * Supported element types. +// * Limited broadcasting to full broadcasting support. +// +// There are a lot of spec revisions that basically generalized elementwise +// to be more normal and a direct translation vs a special case. This +// results in a lot of ONNX test cases that all reduce to the exact same +// thing here, so we simplify. +void mlir::torch::onnx_c::populateDefaultDomainGtoP( + OnnxCustomOpConversionPattern &patterns) {} diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp new file mode 100644 index 000000000000..23af89f329ab --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -0,0 +1,29 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::onnx_c; + +// Simple rewrites for the default domain. +// See: https://onnx.ai/onnx/operators/ +// For operators that are effectively version invariant, we register with +// sinceVersion==1. We interpret this to include the following spec +// diffs that are irrelevant to this level of lowering: +// * Supported element types. +// * Limited broadcasting to full broadcasting support. +// +// There are a lot of spec revisions that basically generalized elementwise +// to be more normal and a direct translation vs a special case. This +// results in a lot of ONNX test cases that all reduce to the exact same +// thing here, so we simplify. +void mlir::torch::onnx_c::populateDefaultDomainQtoZ( + OnnxCustomOpConversionPattern &patterns) {} diff --git a/lib/Conversion/TorchOnnxToTorch/PassDetail.h b/lib/Conversion/TorchOnnxToTorch/PassDetail.h new file mode 100644 index 000000000000..bbcd3413c59c --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/PassDetail.h @@ -0,0 +1,24 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSDETAIL_H +#define TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSDETAIL_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::torch::onnx_c { + +#define GEN_PASS_CLASSES +#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h.inc" + +} // namespace mlir::torch::onnx_c + +#endif // TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSDETAIL_H diff --git a/lib/Conversion/TorchOnnxToTorch/Passes.cpp b/lib/Conversion/TorchOnnxToTorch/Passes.cpp new file mode 100644 index 000000000000..1f8cb05fa02c --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/Passes.cpp @@ -0,0 +1,19 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h" + +namespace { +#define GEN_PASS_REGISTRATION +#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h.inc" +} // end namespace + +void mlir::torch::onnx_c::registerTorchOnnxToTorchPasses() { + ::registerPasses(); +} diff --git a/lib/Conversion/TorchOnnxToTorch/Patterns.cpp b/lib/Conversion/TorchOnnxToTorch/Patterns.cpp new file mode 100644 index 000000000000..6ca7824165d3 --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/Patterns.cpp @@ -0,0 +1,57 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +using llvm::dbgs; +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::onnx_c; + +#define DEBUG_TYPE "torch-onnx" + +LogicalResult OnnxCustomOpConversionPattern::matchAndRewrite( + Torch::OperatorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto foundIt = namedHandlers.find(op.getNameAttr()); + if (foundIt == namedHandlers.end()) + return failure(); + auto ®gies = foundIt->second; + for (const HandlerReg ® : reggies) { + if (domainVersion < reg.sinceVersion) { + LLVM_DEBUG(dbgs() << ": skipping conversion " << foundIt->first + << ", sinceVersion=" << reg.sinceVersion + << ", for domainVersion=" << domainVersion << "\n"); + continue; + } + if (succeeded(reg.callback(OpBinder(op), rewriter))) { + return success(); + } else { + LLVM_DEBUG(dbgs() << ": conversion failed to apply: " << foundIt->first + << ", sinceVersion=" << reg.sinceVersion << "\n"); + } + } + return rewriter.notifyMatchFailure(op, "no matching versioned converter"); +} + +void OnnxCustomOpConversionPattern::populateLegalizedNames( + DenseSet &legalizedNames) { + for (auto it : namedHandlers) + legalizedNames.insert(it.first); +} + +void OnnxCustomOpConversionPattern::onOp(StringRef name, int64_t sinceVersion, + HandlerFn callback) { + SmallString<64> fullName(domainPrefix); + fullName.append(name); + StringAttr nameAttr = StringAttr::get(getContext(), fullName); + namedHandlers[nameAttr].push_back(HandlerReg(callback, sinceVersion)); +} diff --git a/lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp b/lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp new file mode 100644 index 000000000000..ea890bf0f4b6 --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp @@ -0,0 +1,87 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "./PassDetail.h" +#include "mlir/Support/LLVM.h" +#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h" +#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +using llvm::dbgs; +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::onnx_c; + +#define DEBUG_TYPE "torch-onnx" + +namespace { + +int64_t getDefaultOpsetVersion(Operation *containerOp) { + auto attr = + containerOp->getAttrOfType("torch.onnx_meta.opset_version"); + if (!attr) + return 0; + if (auto type = dyn_cast(attr.getType())) { + if (!type || !type.isSigned()) + return 0; + } + return attr.getSInt(); +} + +class ConvertTorchOnnxToTorch + : public ConvertTorchOnnxToTorchBase { +public: + ConvertTorchOnnxToTorch() = default; + void runOnOperation() override { + MLIRContext *context = &getContext(); + + // Populate our patterns for each handled domain. + int64_t defaultOpsetVersion = getDefaultOpsetVersion(getOperation()); + if (defaultOpsetVersion == 0) { + emitError(getOperation().getLoc()) + << "function is missing onnx opset version attribute " + "(torch.onnx_meta.opset_version)"; + return signalPassFailure(); + } + + auto defaultDomainPatterns = + std::make_unique( + context, "onnx.", + /*domainVersion=*/defaultOpsetVersion); + populateDefaultDomainAtoF(*defaultDomainPatterns); + populateDefaultDomainGtoP(*defaultDomainPatterns); + populateDefaultDomainQtoZ(*defaultDomainPatterns); + + // Ask each domain for its handled names and configure the + // conversion target. + ConversionTarget target(*context); + DenseSet legalizedNames; + defaultDomainPatterns->populateLegalizedNames(legalizedNames); + target.addLegalDialect(); + target.addDynamicallyLegalOp([&](Torch::OperatorOp op) { + return !legalizedNames.contains(op.getNameAttr()); + }); + + RewritePatternSet patterns(context); + patterns.insert(std::move(defaultDomainPatterns)); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr> +mlir::torch::onnx_c::createTorchOnnxToTorchPass() { + return std::make_unique(); +} diff --git a/lib/InitAll.cpp b/lib/InitAll.cpp index 0be0ec8ba3ea..ace6c1a40e74 100644 --- a/lib/InitAll.cpp +++ b/lib/InitAll.cpp @@ -22,6 +22,7 @@ #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" #include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h" #include "torch-mlir/Conversion/Passes.h" +#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" @@ -47,8 +48,8 @@ void mlir::torch::registerOptionalInputDialects( void mlir::torch::registerAllPasses() { mlir::torch::registerTorchPasses(); mlir::torch::registerTorchConversionPasses(); - mlir::torch::registerConversionPasses(); + mlir::torch::onnx_c::registerTorchOnnxToTorchPasses(); mlir::torch::TMTensor::registerPasses(); #ifdef TORCH_MLIR_ENABLE_REFBACKEND diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir new file mode 100644 index 000000000000..e2123ac5e057 --- /dev/null +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -0,0 +1,97 @@ +// RUN: torch-mlir-opt <%s -convert-torch-onnx-to-torch | FileCheck %s +// Generally, the test cases accumulated here come from running the importer +// over all included backend tests that involve simple ops with no model +// level constants. This is a pragmatic choice which lets us have a lot +// of tests in this file, whereas the others tend to be more bespoke. + +// CHECK-LABEL: func.func @test_abs +func.func @test_abs(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.abs %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Abs"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: @test_add +func.func @test_add(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.add.Tensor %arg0, %arg1, %[[INT1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Add"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: @test_add_bcast +func.func @test_add_bcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.add.Tensor %arg0, %arg1, %[[INT1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Add"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: @test_add_uint8 +func.func @test_add_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vtensor<[3,4,5],ui8>) -> !torch.vtensor<[3,4,5],ui8> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.add.Tensor %arg0, %arg1, %[[INT1]] : !torch.vtensor<[3,4,5],ui8>, !torch.vtensor<[3,4,5],ui8>, !torch.int -> !torch.vtensor<[3,4,5],ui8> + %0 = torch.operator "onnx.Add"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],ui8>, !torch.vtensor<[3,4,5],ui8>) -> !torch.vtensor<[3,4,5],ui8> + return %0 : !torch.vtensor<[3,4,5],ui8> +} + +// CHECK-LABEL: @test_and_bcast3v1d +func.func @test_and_bcast3v1d(%arg0: !torch.vtensor<[3,4,5],i1>, %arg1: !torch.vtensor<[5],i1>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.logical_and %arg0, %arg1 : !torch.vtensor<[3,4,5],i1>, !torch.vtensor<[5],i1> -> !torch.vtensor<[3,4,5],i1> + %0 = torch.operator "onnx.And"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],i1>, !torch.vtensor<[5],i1>) -> !torch.vtensor<[3,4,5],i1> + return %0 : !torch.vtensor<[3,4,5],i1> +} + +// CHECK-LABEL: @test_argmax_default_axis_example +func.func @test_argmax_default_axis_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[1,2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 0 + // CHECK: %[[BOOL:.*]] = torch.constant.bool true + // CHECK: torch.aten.argmax %arg0, %[[INT]], %[[BOOL]] : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,2],si64> + %0 = torch.operator "onnx.ArgMax"(%arg0) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[1,2],si64> + return %0 : !torch.vtensor<[1,2],si64> +} + +// CHECK-LABEL: @test_argmax_negative_axis_keepdims_example +func.func @test_argmax_negative_axis_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,1],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 1 + // CHECK: %[[BOOL:.*]] = torch.constant.bool true + // CHECK: torch.aten.argmax %arg0, %[[INT]], %[[BOOL]] : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,1],si64> + %0 = torch.operator "onnx.ArgMax"(%arg0) {torch.onnx.axis = -1 : si64, torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,1],si64> + return %0 : !torch.vtensor<[2,1],si64> +} + +// CHECK-LABEL: @test_argmax_no_keepdims_example +func.func @test_argmax_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 1 + // CHECK: %[[BOOL:.*]] = torch.constant.bool false + // CHECK: torch.aten.argmax %arg0, %[[INT]], %[[BOOL]] : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2],si64> + %0 = torch.operator "onnx.ArgMax"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> + return %0 : !torch.vtensor<[2],si64> +} + +// CHECK-LABEL: @test_argmin_default_axis_example +func.func @test_argmin_default_axis_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[1,2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 0 + // CHECK: %[[BOOL:.*]] = torch.constant.bool true + // CHECK: torch.aten.argmin %arg0, %[[INT]], %[[BOOL]] : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,2],si64> + %0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[1,2],si64> + return %0 : !torch.vtensor<[1,2],si64> +} + +// CHECK-LABEL: @test_argmin_negative_axis_keepdims_example +func.func @test_argmin_negative_axis_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,1],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 1 + // CHECK: %[[BOOL:.*]] = torch.constant.bool true + // CHECK: torch.aten.argmin %arg0, %[[INT]], %[[BOOL]] : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2,1],si64> + %0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.axis = -1 : si64, torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,1],si64> + return %0 : !torch.vtensor<[2,1],si64> +} + +// CHECK-LABEL: @test_argmin_no_keepdims_example +func.func @test_argmin_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 1 + // CHECK: %[[BOOL:.*]] = torch.constant.bool false + // CHECK: torch.aten.argmin %arg0, %[[INT]], %[[BOOL]] : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool -> !torch.vtensor<[2],si64> + %0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> + return %0 : !torch.vtensor<[2],si64> +} diff --git a/test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir b/test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir new file mode 100644 index 000000000000..22d5e2d35183 --- /dev/null +++ b/test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir @@ -0,0 +1,18 @@ +// RUN: torch-mlir-opt <%s -split-input-file -verify-diagnostics -convert-torch-onnx-to-torch + +module { + func.func @test_argmax_no_keepdims_random_select_last_index(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,4],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // TODO: Unsupported torch.onnx.select_last_index + // expected-error @+1 {{failed to legalize operation 'torch.operator'}} + %0 = torch.operator "onnx.ArgMax"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,4],si64> + return %0 : !torch.vtensor<[2,4],si64> + } +} + +// ----- +func.func @test_argmin_no_keepdims_example_select_last_index(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // TODO: Unsupported torch.onnx.select_last_index + // expected-error @+1 {{failed to legalize operation 'torch.operator'}} + %0 = torch.operator "onnx.ArgMin"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.keepdims = 0 : si64, torch.onnx.select_last_index = 1 : si64} : (!torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> + return %0 : !torch.vtensor<[2],si64> +}