diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index c18d681055aa..4df2e0f88eb1 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -1364,6 +1364,79 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, data, dimValueList); return success(); }); + patterns.onOp( + "Flatten", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // Flatten means to partition the input tensor's dimensions + // into a "left range" spanning 0 to axis - 1 and a "right range" + // spanning axis to rank - 1. Each range is then collapsed + // into a single dimension, resulting in a 2-D tensor. + // If either range is empty, it is replaced with a single + // dimension of size 1. + // + // For example, for a 4-D input tensor of shape (a, b, c, d) + // and axis==2, flatten produces a 2-D tensor of shape + // (a*b, c*d). + // + // If instead axis==0, the left range is empty, and the result + // is (1, a*b*c*d). + + Torch::ValueTensorType resultType; + Value operand; + int64_t axis; + if (binder.tensorOperand(operand) || + binder.s64IntegerAttr(axis, "axis", 1) || + binder.tensorResultType(resultType)) + return failure(); + + // If axis is negative, count from the right instead of left + int64_t rank = + cast(operand.getType()).getSizes().size(); + if (axis < 0) + axis = rank + axis; + + Value collapsedRight; + auto baseType = Torch::ValueTensorType::getWithLeastStaticInformation( + binder.op->getContext()); + + if (axis >= rank) { + // If the right range is empty, add a dim of size 1 to the + // right side of the shape: + // cr = torch.unsqueeze(x, x.ndim) + Value rankConst = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(rank)); + collapsedRight = rewriter.create( + binder.getLoc(), baseType, operand, rankConst); + } else { + // Otherwise, collapse the right range into a single dimension: + // cr = torch._prims.collapse(x, axis, x.ndim - 1) + Value axisConst = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(axis)); + Value rankLess1Const = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(rank - 1)); + collapsedRight = rewriter.create( + binder.getLoc(), baseType, operand, axisConst, rankLess1Const); + } + + Value zero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + + if (axis <= 0) { + // If the left range is empty, add a dim of size 1 to the + // left side of the shape: + // torch.unsqueeze(cr, 0) + rewriter.replaceOpWithNewOp( + binder.op, resultType, collapsedRight, zero); + return success(); + } + + // Otherwise, collapse the left range into a single dimension: + // torch._prims.collapse(cr, 0, axis - 1) + Value axisLess1Const = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(axis - 1)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, collapsedRight, zero, axisLess1Const); + return success(); + }); patterns.onOp("Floor", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 42a0fe743bc2..c2d3c12a7b92 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -1062,3 +1062,116 @@ func.func @ints_constant() -> !torch.vtensor<[2], si64> attributes {torch.onnx_m return %0 : !torch.vtensor<[2],si64> } +// CHECK-LABEL: @test_flatten_4d_axis_2 +func.func @test_flatten_4d_axis_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2 + // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3 + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 1 + // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[6,20],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 2 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> + return %0 : !torch.vtensor<[6,20],f32> +} + +// CHECK-LABEL: @test_flatten_4d_axis_0 +func.func @test_flatten_4d_axis_0(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3 + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0 + // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,120],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32> + return %0 : !torch.vtensor<[1,120],f32> +} + +// CHECK-LABEL: @test_flatten_4d_axis_4 +func.func @test_flatten_4d_axis_4(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[120,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_INDEX:.*]] = torch.constant.int 4 + // CHECK-DAG: %[[CR:.*]] = torch.aten.unsqueeze %arg0, %[[RIGHT_INDEX]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 3 + // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[120,1],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 4 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[120,1],f32> + return %0 : !torch.vtensor<[120,1],f32> +} + +// CHECK-LABEL: @test_flatten_4d_axis_negative_2 +func.func @test_flatten_4d_axis_negative_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2 + // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3 + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 1 + // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[6,20],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -2 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> + return %0 : !torch.vtensor<[6,20],f32> +} + +// CHECK-LABEL: @test_flatten_4d_axis_negative_1 +func.func @test_flatten_4d_axis_negative_1(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[24,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 3 + // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3 + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 2 + // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[24,5],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[24,5],f32> + return %0 : !torch.vtensor<[24,5],f32> +} + +// CHECK-LABEL: @test_flatten_4d_axis_negative_4 +func.func @test_flatten_4d_axis_negative_4(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3 + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0 + // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,120],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -4 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32> + return %0 : !torch.vtensor<[1,120],f32> +} + +// CHECK-LABEL: @test_flatten_2d_axis_1 +func.func @test_flatten_2d_axis_1(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 0 + // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> + return %0 : !torch.vtensor<[2,3],f32> +} + +// CHECK-LABEL: @test_flatten_1d_axis_0 +func.func @test_flatten_1d_axis_0(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0 + // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,2],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32> + return %0 : !torch.vtensor<[1,2],f32> +} + +// CHECK-LABEL: @test_flatten_1d_axis_negative_1 +func.func @test_flatten_1d_axis_negative_1(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0 + // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,2],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32> + return %0 : !torch.vtensor<[1,2],f32> +} + +// COM: CHECK-LABEL: @test_flatten_1d_axis_1 +func.func @test_flatten_1d_axis_1(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_INDEX:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[CR:.*]] = torch.aten.unsqueeze %arg0, %[[RIGHT_INDEX]] : !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 0 + // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[2,1],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[2,1],f32> + return %0 : !torch.vtensor<[2,1],f32> +}