-
Notifications
You must be signed in to change notification settings - Fork 519
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial TorchOnnxToTorch conversion pipeline. (#2585)
Adds a pipeline to convert custom ops and metadata represented as `torch.operator` custom ops to corresponding `torch` ops where possible. This is part of a multi-part approach for building ONNX import in as a regular feature of torch-mlir. It is focused on the conversions vs the infra. We will end up maintaining a [pure-python importer](https://github.com/nod-ai/SHARK-Turbine/blob/main/python/shark_turbine/importers/onnx_importer.py) to go with this in torch-mlir, and we will also maintain test case generation utilities derived from it. I have left substantial documentation in the README of the conversion directory, including the recommended approach that we will take to keep building this out. (note that this organizes the code to coincide with the refactoring in #2442 versus the current flat arrangement)
- Loading branch information
1 parent
d50d3aa
commit e06efc5
Showing
19 changed files
with
897 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
4 changes: 4 additions & 0 deletions
4
include/torch-mlir/Conversion/TorchOnnxToTorch/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
set(LLVM_TARGET_DEFINITIONS Passes.td) | ||
mlir_tablegen(Passes.h.inc -gen-pass-decls) | ||
add_public_tablegen_target(TorchMLIRConversionTorchOnnxToTorchPassIncGen) | ||
add_mlir_doc(Passes TorchMLIRConversionTorchOnnxToTorchPasses ./ -gen-pass-doc) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
//===------------------------------------------------------------*- C++ -*-===// | ||
// | ||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_H | ||
#define TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_H | ||
|
||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include <memory> | ||
|
||
namespace mlir::torch::onnx_c { | ||
|
||
std::unique_ptr<OperationPass<func::FuncOp>> createTorchOnnxToTorchPass(); | ||
|
||
/// Registers all torch-mlir conversion passes. | ||
void registerTorchOnnxToTorchPasses(); | ||
|
||
} // namespace mlir::torch::onnx_c | ||
|
||
#endif // TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSES | ||
#define TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSES | ||
|
||
include "mlir/Pass/PassBase.td" | ||
|
||
def ConvertTorchOnnxToTorch : Pass<"convert-torch-onnx-to-torch", "func::FuncOp"> { | ||
let summary = "Converts ONNX custom ops in the torch dialect to native torch ops"; | ||
let description = [{ | ||
Converts equivalent ONNX custom ops to built-in equivalents. | ||
|
||
See the README for a detailed description of how this operates. | ||
}]; | ||
|
||
let constructor = "mlir::torch::onnx_c::createTorchOnnxToTorchPass()"; | ||
} | ||
|
||
#endif // TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSES |
169 changes: 169 additions & 0 deletions
169
include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
//===------------------------------------------------------------*- C++ -*-===// | ||
// | ||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef TORCHMLIR_CONVERSION_TORCHONNX_CONVERSION_UTILS_H | ||
#define TORCHMLIR_CONVERSION_TORCHONNX_CONVERSION_UTILS_H | ||
|
||
#include "mlir/IR/BuiltinAttributes.h" | ||
#include "mlir/Support/LogicalResult.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" | ||
#include "llvm/ADT/DenseMap.h" | ||
#include "llvm/ADT/SmallString.h" | ||
#include "llvm/ADT/SmallVector.h" | ||
|
||
namespace mlir::torch::onnx_c { | ||
|
||
/// Used during ONNX pattern matching to bind common patterns of operands, | ||
/// result types and attributes to local variables in a way that is easy | ||
/// to fail the pattern if constraints are violated. Most methods return | ||
/// a ParseResult, which allows for chaining like: | ||
/// | ||
/// if (binder.tensorOperand(foo) || binder.tensorResultType(t)) | ||
/// return failure(); | ||
struct OpBinder { | ||
OpBinder(Operation *op) : op(op) {} | ||
|
||
Location getLoc() { return op->getLoc(); } | ||
|
||
// Operand matches of different arities. | ||
ParseResult tensorOperand(Value &value0) { | ||
if (op->getNumOperands() != 1) | ||
return failure(); | ||
value0 = op->getOperand(0); | ||
if (!toValidTensorType(value0.getType())) | ||
return failure(); | ||
return success(); | ||
} | ||
|
||
ParseResult tensorOperands(Value &value0, Value &value1) { | ||
if (op->getNumOperands() != 2) | ||
return failure(); | ||
value0 = op->getOperand(0); | ||
value1 = op->getOperand(1); | ||
if (!toValidTensorType(value0.getType()) || | ||
!toValidTensorType(value1.getType())) | ||
return failure(); | ||
return success(); | ||
} | ||
|
||
// Result type matchers of different arities. | ||
ParseResult tensorResultType(Torch::ValueTensorType &type0) { | ||
if (op->getNumResults() != 1) | ||
return failure(); | ||
auto t = toValidTensorType(op->getResult(0).getType()); | ||
if (!t) | ||
return failure(); | ||
type0 = t; | ||
return success(); | ||
} | ||
|
||
// Attribute accessors. | ||
ParseResult s64BoolAttr(bool &value, StringRef nameSuffix, | ||
bool defaultValue = false) { | ||
SmallString<64> name("torch.onnx."); | ||
name.append(nameSuffix); | ||
auto attr = op->getAttr(name); | ||
if (!attr) { | ||
value = defaultValue; | ||
return success(); | ||
} | ||
if (auto integerAttr = dyn_cast<IntegerAttr>(attr)) { | ||
IntegerType t = cast<IntegerType>(integerAttr.getType()); | ||
if (!t.isSigned() || t.getWidth() != 64) | ||
return failure(); | ||
value = static_cast<bool>(integerAttr.getSInt()); | ||
return success(); | ||
} | ||
return failure(); | ||
} | ||
|
||
ParseResult s64IntegerAttr(int64_t &value, StringRef nameSuffix, | ||
int64_t defaultValue = 0) { | ||
SmallString<64> name("torch.onnx."); | ||
name.append(nameSuffix); | ||
auto attr = op->getAttr(name); | ||
if (!attr) { | ||
value = defaultValue; | ||
return success(); | ||
} | ||
if (auto integerAttr = dyn_cast<IntegerAttr>(attr)) { | ||
IntegerType t = cast<IntegerType>(integerAttr.getType()); | ||
if (!t.isSigned() || t.getWidth() != 64) | ||
return failure(); | ||
value = integerAttr.getSInt(); | ||
return success(); | ||
} | ||
return failure(); | ||
} | ||
|
||
Torch::ValueTensorType toValidTensorType(Type t) { | ||
auto tt = dyn_cast<Torch::ValueTensorType>(t); | ||
if (tt && tt.hasSizes()) | ||
return tt; | ||
return {}; | ||
} | ||
|
||
Operation *op; | ||
}; | ||
|
||
/// We use a single pattern per ONNX domain to handle all named custom | ||
/// ops. | ||
/// This allows us to avoid the n^2 problem on pattern application by | ||
/// implementing a secondary index based on the name and sinceVersion | ||
/// attributes. | ||
/// It also lets us add some ergonomics for trivial cases. | ||
class OnnxCustomOpConversionPattern | ||
: public OpConversionPattern<Torch::OperatorOp> { | ||
public: | ||
using HandlerFn = LogicalResult (*)(OpBinder binder, | ||
ConversionPatternRewriter &rewriter); | ||
struct HandlerReg { | ||
HandlerReg(HandlerFn callback, int64_t sinceVersion) | ||
: callback(callback), sinceVersion(sinceVersion) {} | ||
HandlerFn callback; | ||
int64_t sinceVersion; | ||
}; | ||
|
||
OnnxCustomOpConversionPattern(MLIRContext *context, std::string domainPrefix, | ||
int64_t domainVersion) | ||
: OpConversionPattern(context), domainPrefix(std::move(domainPrefix)), | ||
domainVersion(domainVersion) {} | ||
|
||
LogicalResult | ||
matchAndRewrite(Torch::OperatorOp op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override; | ||
|
||
/// Adds all fully qualified operator names to the given set. | ||
/// This is typically used for implementing a dynamic legality | ||
/// check for torch.operator names. | ||
void populateLegalizedNames(DenseSet<StringAttr> &legalizedNames); | ||
|
||
/// Register a conversion for a specific ONNX operator. For the | ||
/// default domain, this is the canonical ONNX operator name (i.e. | ||
/// "Acos"). | ||
/// Multiple conversions can be registered for the same op, most | ||
/// commonly differing by their `sinceVersion`. | ||
void onOp(StringRef name, int64_t sinceVersion, HandlerFn callback); | ||
|
||
private: | ||
std::string domainPrefix; | ||
int64_t domainVersion; | ||
DenseMap<StringAttr, SmallVector<HandlerReg, 1>> namedHandlers; | ||
}; | ||
|
||
// Patterns are split into chunks to speed compile time and reduce some | ||
// contention on the same source files. | ||
void populateDefaultDomainAtoF(OnnxCustomOpConversionPattern &patterns); | ||
void populateDefaultDomainGtoP(OnnxCustomOpConversionPattern &patterns); | ||
void populateDefaultDomainQtoZ(OnnxCustomOpConversionPattern &patterns); | ||
|
||
} // namespace mlir::torch::onnx_c | ||
|
||
#endif // TORCHMLIR_CONVERSION_TORCHONNX_CONVERSION_UTILS_H |
133 changes: 133 additions & 0 deletions
133
include/torch-mlir/Conversion/TorchOnnxToTorch/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
# TorchOnnx To Torch Conversions | ||
|
||
We enable the direct representation of many ONNX features directly in | ||
the `torch` dialect as `torch.operator` custom ops with names like | ||
`onnx.{OperatorName}`. The majority of ONNX operators are represented | ||
with a systematic transformation. See | ||
[onnx_importer.py](https://github.com/nod-ai/SHARK-Turbine/blob/main/python/shark_turbine/importers/onnx_importer.py) | ||
for the reference importer which complies with the rules below | ||
(this is planned to be upstreamed to torch-mlir proper in the near | ||
future). | ||
|
||
## Adding new ONNX operators | ||
|
||
With the exception of certain special or complicated ONNX operators, most | ||
are relatively straight-forward to map, following this general procedure: | ||
|
||
* Plan the ops you wish to support by consulting the | ||
[ONNX operator database](https://onnx.ai/onnx/operators/). | ||
* This database has detailed diffs wrt different support versions but | ||
at the level of detail we operate, most version diffs are inconsequential | ||
and just require a bit more pattern support. | ||
* This typically applies to generalization of broadcasting semantics, | ||
expanded type support, and other things of the like. | ||
* *Prerequisite*: Add support for the op to torch-mlir if it does not | ||
already exist. | ||
* Open the corresponding implementation file `DefaultDomainXtoY.cpp` | ||
corresponding with the alphabetic sort of the op and add a conversion. | ||
* Generate successful test cases: | ||
* Either run the Turbine importer to produce MLIR output for all | ||
ops/models in the ONNX test suite or use a dump that someone has | ||
generated: | ||
* [2023-Nov-21](https://drive.google.com/file/d/1P6QaRXGnCeApjdjNmykLxWa-yqMmIO-d/view?usp=sharing) | ||
* There are often many variants of tests for checking conformance of | ||
different historic ONNX encodings, but these are often not load bearing | ||
at the MLIR level. | ||
* Pick a handful of test cases and add them to | ||
`test/Conversion/TorchOnnxToTorch/simple_ops_x_to_y.mlir` corresponding to an | ||
alphabetic breakdown. At this time, ignore tests that are not exercising | ||
useful differences in the pattern implementations. | ||
* Generate failure test cases: | ||
* Some ops have forms that do not (easily) map to torch-mlir. If you leave | ||
an op under-implemented, add a failing test case to | ||
`test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir`. | ||
* Optional but recommended: Use your test case files to fuzz against the | ||
torch-mlir backend of your choice by running a backend conversion pipeline | ||
and fixing any crashes/issues. | ||
* Send a patch with your changes. | ||
|
||
## ONNX proto to `torch` dialect mapping | ||
|
||
### Type Conversion | ||
|
||
* Tensors: ONNX tensor types are converted to `torch.vtensor` | ||
with static and dynamic dimensions. We require that shape | ||
inference has run to produce ranked tensors. | ||
* Tensor element types are directly converted to corresponding | ||
MLIR types as used by the rest of torch-mlir. | ||
* String, sequence and sparse tensor types are presently not mapped. | ||
|
||
### Attributes | ||
|
||
A subset of attributes types are converted directly to an attribute | ||
dict on the op with a name like `torch.onnx.{AttributeName}`. The | ||
following attribute type mappings are made: | ||
|
||
* `FLOAT`: `FloatAttr` | ||
* `INT`: Signed `IntegerAttr` of width 64 | ||
* `STRING`: `StringAttr` | ||
* `TENSOR`: Converted to one of: | ||
* `DenseResourceElementsAttr` for inlined `raw_data` | ||
* `DenseElementsAttr` for splats | ||
* `DenseElementsAttr` for inlined typed proto initialization | ||
* `FLOATS`: `ArrayAttr` of `FloatAttr` | ||
* `INTS`: `ArrayAttr` of signed `IntegerAttr` of width 64 | ||
* `STRINGS`: `ArrayAttr` of `StringAttr` | ||
* `TENSORS`: `ArrayAttr` of corresponding `TENSOR` conversion | ||
|
||
The following attribute types have no present, systematic conversion. | ||
Their presence on an op indicates that the op is a special form, which | ||
must be handled specially: | ||
|
||
* `GRAPH` | ||
* `SPARSE_TENSOR` (TBD: it is possible to handle this systematically if | ||
useful). | ||
* `TYPE_PROTO` (TBD: it may be possible to handle this systematically if | ||
useful). | ||
* Plural equivalents of the above. | ||
|
||
### Default operation conversion | ||
|
||
Operations are converted to a `torch.operator` with name `onnx.{OperatorName}`. | ||
The constraint that the ONNX graph is topologically sorted and free of | ||
cycles matches the SSA form. Operands and results are mapped directly. | ||
|
||
This conversion only applies to the default (empty) domain. | ||
|
||
### Quantization information | ||
|
||
Quantization parameters are carried out of line in the ONNX protobuf | ||
and will be repatriated upon import to torch. The exact mechanism is | ||
not yet implemented. | ||
|
||
### Version and metadata | ||
|
||
The `IsolatedFromAbove` parent of the ops can contain the following | ||
metadata: | ||
|
||
* `torch.onnx_meta.ir_version`: 64bit `IntegerAttr` corresponding to | ||
`ModelProto.ir_version`. | ||
* `torch.onnx_meta.producer_name`: `StringAttr` corresponding to | ||
`ModelProto.producer_name`. | ||
* `torch.onnx_meta.producer_version`: `StringAttr` corresponding to | ||
`ModelProto.producer_version`. | ||
* `torch.onnx_meta.opset_version`: 64bit `IntegerAttr` corresponding | ||
to `ModelProto.opset_import.version` for the domain "" (empty). | ||
Will be ommitted if the default opset is not included. | ||
* `torch.onnx_meta.opset_versions`: DictAttr of 64bit `IntegerAttr` | ||
for each non default domain. | ||
|
||
Generally, the importer handles variations in `ir_version` whereas | ||
the transformations here handle opset version differences. Version | ||
independent transformations are encouraged where possible if there | ||
are only minor variations of an op. Major variations should use | ||
`since_version` sensitive patterns. | ||
|
||
### Special op forms | ||
|
||
Certain ONNX operators map to different structural components of | ||
torch-mlir's representation: | ||
|
||
* `ConstantOfShape`: Mapped to `torch.vtensor.literal` with | ||
a corresponding `value` attribute. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.