Skip to content

Commit

Permalink
Initial TorchOnnxToTorch conversion pipeline. (#2585)
Browse files Browse the repository at this point in the history
Adds a pipeline to convert custom ops and metadata represented as
`torch.operator` custom ops to corresponding `torch` ops where possible.

This is part of a multi-part approach for building ONNX import in as a
regular feature of torch-mlir. It is focused on the conversions vs the
infra. We will end up maintaining a [pure-python
importer](https://github.com/nod-ai/SHARK-Turbine/blob/main/python/shark_turbine/importers/onnx_importer.py)
to go with this in torch-mlir, and we will also maintain test case
generation utilities derived from it.

I have left substantial documentation in the README of the conversion
directory, including the recommended approach that we will take to keep
building this out.

(note that this organizes the code to coincide with the refactoring in
#2442 versus the current flat arrangement)
  • Loading branch information
stellaraccident authored Nov 22, 2023
1 parent d50d3aa commit e06efc5
Show file tree
Hide file tree
Showing 19 changed files with 897 additions and 6 deletions.
2 changes: 2 additions & 0 deletions include/torch-mlir/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
add_subdirectory(TorchOnnxToTorch)

set(LLVM_TARGET_DEFINITIONS Passes.td)
if(TORCH_MLIR_ENABLE_STABLEHLO)
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO)
Expand Down
4 changes: 4 additions & 0 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Tablegen the pass declaration/registration boilerplate (Passes.h.inc)
# from the pass definitions in Passes.td.
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
add_public_tablegen_target(TorchMLIRConversionTorchOnnxToTorchPassIncGen)
# Also emit user-facing pass documentation from the same definitions.
add_mlir_doc(Passes TorchMLIRConversionTorchOnnxToTorchPasses ./ -gen-pass-doc)
27 changes: 27 additions & 0 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifndef TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_H
#define TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_H

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include <memory>

namespace mlir::torch::onnx_c {

/// Creates the pass that converts ONNX custom ops (torch.operator ops
/// with `onnx.*` names) in the torch dialect to native torch ops. Runs
/// on one func::FuncOp at a time. See Passes.td for the definition.
std::unique_ptr<OperationPass<func::FuncOp>> createTorchOnnxToTorchPass();

/// Registers the TorchOnnxToTorch conversion passes declared in
/// Passes.td with the pass registry.
void registerTorchOnnxToTorchPasses();

} // namespace mlir::torch::onnx_c

#endif // TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_H
26 changes: 26 additions & 0 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/Passes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifndef TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSES
#define TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSES

include "mlir/Pass/PassBase.td"

// Function-scoped conversion pass; exposed on the command line as
// `convert-torch-onnx-to-torch`. The constructor is implemented in the
// onnx_c namespace (see Passes.h).
def ConvertTorchOnnxToTorch : Pass<"convert-torch-onnx-to-torch", "func::FuncOp"> {
let summary = "Converts ONNX custom ops in the torch dialect to native torch ops";
let description = [{
Converts equivalent ONNX custom ops to built-in equivalents.

See the README for a detailed description of how this operates.
}];

let constructor = "mlir::torch::onnx_c::createTorchOnnxToTorchPass()";
}

#endif // TORCHMLIR_CONVERSION_TORCHONNX_TO_TORCH_PASSES
169 changes: 169 additions & 0 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifndef TORCHMLIR_CONVERSION_TORCHONNX_CONVERSION_UTILS_H
#define TORCHMLIR_CONVERSION_TORCHONNX_CONVERSION_UTILS_H

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"

namespace mlir::torch::onnx_c {

/// Used during ONNX pattern matching to bind common patterns of operands,
/// result types and attributes to local variables in a way that is easy
/// to fail the pattern if constraints are violated. Most methods return
/// a ParseResult, which allows for chaining like:
///
/// if (binder.tensorOperand(foo) || binder.tensorResultType(t))
/// return failure();
/// Used during ONNX pattern matching to bind common patterns of operands,
/// result types and attributes to local variables in a way that is easy
/// to fail the pattern if constraints are violated. Most methods return
/// a ParseResult (where success converts to `false`), which allows for
/// chaining like:
///
///   if (binder.tensorOperand(foo) || binder.tensorResultType(t))
///     return failure();
struct OpBinder {
  OpBinder(Operation *op) : op(op) {}

  Location getLoc() { return op->getLoc(); }

  // Operand matchers of different arities.

  /// Binds the single operand to |value0|. Fails if the op does not have
  /// exactly one operand or if it is not a valid (ranked vtensor) type.
  ParseResult tensorOperand(Value &value0) {
    if (op->getNumOperands() != 1)
      return failure();
    value0 = op->getOperand(0);
    if (!toValidTensorType(value0.getType()))
      return failure();
    return success();
  }

  /// Binds exactly two operands, both of which must be valid tensor types.
  ParseResult tensorOperands(Value &value0, Value &value1) {
    if (op->getNumOperands() != 2)
      return failure();
    value0 = op->getOperand(0);
    value1 = op->getOperand(1);
    if (!toValidTensorType(value0.getType()) ||
        !toValidTensorType(value1.getType()))
      return failure();
    return success();
  }

  // Result type matchers of different arities.

  /// Binds the single result type to |type0|. Fails if the op does not
  /// have exactly one result or if it is not a valid tensor type.
  ParseResult tensorResultType(Torch::ValueTensorType &type0) {
    if (op->getNumResults() != 1)
      return failure();
    auto t = toValidTensorType(op->getResult(0).getType());
    if (!t)
      return failure();
    type0 = t;
    return success();
  }

  // Attribute accessors. ONNX attributes are stored on the op with a
  // "torch.onnx." name prefix (see the README); absent attributes bind
  // the caller-provided default rather than failing.

  /// Binds the `torch.onnx.{nameSuffix}` attribute as a bool. ONNX
  /// encodes booleans as si64 integers, so this accepts any si64
  /// attribute and maps nonzero to true.
  ParseResult s64BoolAttr(bool &value, StringRef nameSuffix,
                          bool defaultValue = false) {
    // Delegate to the integer accessor so the lookup/validation logic
    // lives in exactly one place.
    int64_t rawValue;
    if (s64IntegerAttr(rawValue, nameSuffix,
                       static_cast<int64_t>(defaultValue)))
      return failure();
    value = static_cast<bool>(rawValue);
    return success();
  }

  /// Binds the `torch.onnx.{nameSuffix}` attribute as an int64_t. Fails
  /// if the attribute is present but is not a signed 64-bit IntegerAttr.
  ParseResult s64IntegerAttr(int64_t &value, StringRef nameSuffix,
                             int64_t defaultValue = 0) {
    SmallString<64> name("torch.onnx.");
    name.append(nameSuffix);
    auto attr = op->getAttr(name);
    if (!attr) {
      value = defaultValue;
      return success();
    }
    auto integerAttr = dyn_cast<IntegerAttr>(attr);
    if (!integerAttr)
      return failure();
    // An IntegerAttr's type is not necessarily an IntegerType (it may
    // be IndexType); use dyn_cast so a mismatch fails the match instead
    // of asserting.
    auto t = dyn_cast<IntegerType>(integerAttr.getType());
    if (!t || !t.isSigned() || t.getWidth() != 64)
      return failure();
    value = integerAttr.getSInt();
    return success();
  }

  /// Returns |t| as a Torch::ValueTensorType if it is one with known
  /// sizes (i.e. ranked); otherwise returns a null type.
  Torch::ValueTensorType toValidTensorType(Type t) {
    auto tt = dyn_cast<Torch::ValueTensorType>(t);
    if (tt && tt.hasSizes())
      return tt;
    return {};
  }

  // The op being matched. Non-owning; valid for the duration of the
  // pattern application.
  Operation *op;
};

/// We use a single pattern per ONNX domain to handle all named custom
/// ops.
/// This allows us to avoid the n^2 problem on pattern application by
/// implementing a secondary index based on the name and sinceVersion
/// attributes.
/// It also lets us add some ergonomics for trivial cases.
class OnnxCustomOpConversionPattern
: public OpConversionPattern<Torch::OperatorOp> {
public:
// Callback that rewrites one matched op. The binder wraps the matched
// torch.operator op; the callback returns failure to reject the match.
using HandlerFn = LogicalResult (*)(OpBinder binder,
ConversionPatternRewriter &rewriter);
// One registered handler: the callback plus the minimum ONNX opset
// version (sinceVersion) it supports.
struct HandlerReg {
HandlerReg(HandlerFn callback, int64_t sinceVersion)
: callback(callback), sinceVersion(sinceVersion) {}
HandlerFn callback;
int64_t sinceVersion;
};

// |domainPrefix| is prepended to operator names when indexing handlers
// (empty-domain ops use the "onnx." prefix per the README);
// |domainVersion| is the opset version the module imports for that domain.
OnnxCustomOpConversionPattern(MLIRContext *context, std::string domainPrefix,
int64_t domainVersion)
: OpConversionPattern(context), domainPrefix(std::move(domainPrefix)),
domainVersion(domainVersion) {}

// Dispatches to the registered handler for the op's name; defined in the
// conversion library (not in this header).
LogicalResult
matchAndRewrite(Torch::OperatorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;

/// Adds all fully qualified operator names to the given set.
/// This is typically used for implementing a dynamic legality
/// check for torch.operator names.
void populateLegalizedNames(DenseSet<StringAttr> &legalizedNames);

/// Register a conversion for a specific ONNX operator. For the
/// default domain, this is the canonical ONNX operator name (i.e.
/// "Acos").
/// Multiple conversions can be registered for the same op, most
/// commonly differing by their `sinceVersion`.
void onOp(StringRef name, int64_t sinceVersion, HandlerFn callback);

private:
std::string domainPrefix;
int64_t domainVersion;
// Secondary index: fully qualified op name -> registered handlers.
// Inline capacity of 1 because most ops have a single handler.
DenseMap<StringAttr, SmallVector<HandlerReg, 1>> namedHandlers;
};

// Patterns are split into chunks to speed compile time and reduce some
// contention on the same source files.
// Each function registers (via OnnxCustomOpConversionPattern::onOp) the
// handlers for the alphabetic slice of default-domain ONNX operators
// indicated by its name.
void populateDefaultDomainAtoF(OnnxCustomOpConversionPattern &patterns);
void populateDefaultDomainGtoP(OnnxCustomOpConversionPattern &patterns);
void populateDefaultDomainQtoZ(OnnxCustomOpConversionPattern &patterns);

} // namespace mlir::torch::onnx_c

#endif // TORCHMLIR_CONVERSION_TORCHONNX_CONVERSION_UTILS_H
133 changes: 133 additions & 0 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# TorchOnnx To Torch Conversions

We enable the direct representation of many ONNX features directly in
the `torch` dialect as `torch.operator` custom ops with names like
`onnx.{OperatorName}`. The majority of ONNX operators are represented
with a systematic transformation. See
[onnx_importer.py](https://github.com/nod-ai/SHARK-Turbine/blob/main/python/shark_turbine/importers/onnx_importer.py)
for the reference importer which complies with the rules below
(this is planned to be upstreamed to torch-mlir proper in the near
future).

## Adding new ONNX operators

With the exception of certain special or complicated ONNX operators, most
are relatively straight-forward to map, following this general procedure:

* Plan the ops you wish to support by consulting the
[ONNX operator database](https://onnx.ai/onnx/operators/).
* This database has detailed diffs across operator versions, but at
the level of detail we operate at, most version differences are
inconsequential and just require a bit more pattern support.
* This typically applies to generalization of broadcasting semantics,
expanded type support, and other things of the like.
* *Prerequisite*: Add support for the op to torch-mlir if it does not
already exist.
* Open the corresponding implementation file `DefaultDomainXtoY.cpp`
corresponding with the alphabetic sort of the op and add a conversion.
* Generate successful test cases:
* Either run the Turbine importer to produce MLIR output for all
ops/models in the ONNX test suite or use a dump that someone has
generated:
* [2023-Nov-21](https://drive.google.com/file/d/1P6QaRXGnCeApjdjNmykLxWa-yqMmIO-d/view?usp=sharing)
* There are often many variants of tests for checking conformance of
different historic ONNX encodings, but these are often not load bearing
at the MLIR level.
* Pick a handful of test cases and add them to
`test/Conversion/TorchOnnxToTorch/simple_ops_x_to_y.mlir` corresponding to an
alphabetic breakdown. At this time, ignore tests that are not exercising
useful differences in the pattern implementations.
* Generate failure test cases:
* Some ops have forms that do not (easily) map to torch-mlir. If you leave
an op under-implemented, add a failing test case to
`test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir`.
* Optional but recommended: Use your test case files to fuzz against the
torch-mlir backend of your choice by running a backend conversion pipeline
and fixing any crashes/issues.
* Send a patch with your changes.

## ONNX proto to `torch` dialect mapping

### Type Conversion

* Tensors: ONNX tensor types are converted to `torch.vtensor`
with static and dynamic dimensions. We require that shape
inference has run to produce ranked tensors.
* Tensor element types are directly converted to corresponding
MLIR types as used by the rest of torch-mlir.
* String, sequence and sparse tensor types are presently not mapped.

### Attributes

A subset of attribute types is converted directly to attributes
on the op, each with a name like `torch.onnx.{AttributeName}`. The
following attribute type mappings are made:

* `FLOAT`: `FloatAttr`
* `INT`: Signed `IntegerAttr` of width 64
* `STRING`: `StringAttr`
* `TENSOR`: Converted to one of:
* `DenseResourceElementsAttr` for inlined `raw_data`
* `DenseElementsAttr` for splats
* `DenseElementsAttr` for inlined typed proto initialization
* `FLOATS`: `ArrayAttr` of `FloatAttr`
* `INTS`: `ArrayAttr` of signed `IntegerAttr` of width 64
* `STRINGS`: `ArrayAttr` of `StringAttr`
* `TENSORS`: `ArrayAttr` of corresponding `TENSOR` conversion

The following attribute types have no present, systematic conversion.
Their presence on an op indicates that the op is a special form, which
must be handled specially:

* `GRAPH`
* `SPARSE_TENSOR` (TBD: it is possible to handle this systematically if
useful).
* `TYPE_PROTO` (TBD: it may be possible to handle this systematically if
useful).
* Plural equivalents of the above.

### Default operation conversion

Operations are converted to a `torch.operator` with name `onnx.{OperatorName}`.
The constraint that the ONNX graph is topologically sorted and free of
cycles matches the SSA form. Operands and results are mapped directly.

This conversion only applies to the default (empty) domain.

### Quantization information

Quantization parameters are carried out of line in the ONNX protobuf
and will be repatriated upon import to torch. The exact mechanism is
not yet implemented.

### Version and metadata

The `IsolatedFromAbove` parent of the ops can contain the following
metadata:

* `torch.onnx_meta.ir_version`: 64bit `IntegerAttr` corresponding to
`ModelProto.ir_version`.
* `torch.onnx_meta.producer_name`: `StringAttr` corresponding to
`ModelProto.producer_name`.
* `torch.onnx_meta.producer_version`: `StringAttr` corresponding to
`ModelProto.producer_version`.
* `torch.onnx_meta.opset_version`: 64bit `IntegerAttr` corresponding
to `ModelProto.opset_import.version` for the domain "" (empty).
Will be omitted if the default opset is not included.
* `torch.onnx_meta.opset_versions`: DictAttr of 64bit `IntegerAttr`
for each non default domain.

Generally, the importer handles variations in `ir_version` whereas
the transformations here handle opset version differences. Version
independent transformations are encouraged where possible if there
are only minor variations of an op. Major variations should use
`since_version` sensitive patterns.

### Special op forms

Certain ONNX operators map to different structural components of
torch-mlir's representation:

* `ConstantOfShape`: Mapped to `torch.vtensor.literal` with
a corresponding `value` attribute.

13 changes: 8 additions & 5 deletions lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,19 @@ set(LinkedLibs
MLIRTosaDialect
MLIRSupport

TorchMLIRTorchPasses
TorchMLIRTorchConversionDialect

# Dialects.
TorchMLIRTMTensorDialect
TorchMLIRTorchDialect
TorchMLIRTorchConversionPasses
TorchMLIRTorchConversionDialect

# Dialect passes.
TorchMLIRTMTensorPasses
TorchMLIRTMTensorDialect
TorchMLIRTorchConversionPasses
TorchMLIRTorchPasses

# Conversion passes.
TorchMLIRConversionPasses
TorchMLIRTorchOnnxToTorch
)

if(TORCH_MLIR_ENABLE_REFBACKEND)
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
add_subdirectory(TorchOnnxToTorch)
add_subdirectory(TorchToLinalg)
add_subdirectory(TorchToSCF)
add_subdirectory(TorchToArith)
Expand Down
Loading

0 comments on commit e06efc5

Please sign in to comment.