Add center_point_box=1 support in NonMaxSuppression. (#3976)
When center_point_box=1, the supplied boxes come in the
[x_center, y_center, width, height] format; this patch converts them to
the [x1, y1, x2, y2] format before they are consumed.

A corresponding e2e test is added in nod-ai/SHARK-TestSuite#436.
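
For reference, the conversion is the standard center-to-corner box transform: x1 = x_center - width/2, y1 = y_center - height/2, x2 = x_center + width/2, y2 = y_center + height/2. A minimal standalone C++ sketch of the math (the helper name is illustrative, not part of the patch):

#include <array>
#include <cstdio>

// Convert one box from [x_center, y_center, width, height] to
// [x1, y1, x2, y2] (top-left and bottom-right corners).
std::array<float, 4> centerToCorners(const std::array<float, 4> &box) {
  float xc = box[0], yc = box[1], w = box[2], h = box[3];
  return {xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2};
}

int main() {
  // A unit box centered at (0.5, 0.5) maps to corners (0, 0) and (1, 1).
  auto c = centerToCorners({0.5f, 0.5f, 1.0f, 1.0f});
  std::printf("[%g, %g, %g, %g]\n", c[0], c[1], c[2], c[3]);
}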
lpy authored Jan 22, 2025
1 parent 481da8d commit 2564d7a
Showing 2 changed files with 114 additions and 4 deletions.
49 changes: 45 additions & 4 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -3697,11 +3697,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.tensorResultType(resultType))
return failure();

-// TODO: Add support for non-zero center_point_box value.
-if (centerPointBox != 0)
+if (centerPointBox != 0 && centerPointBox != 1)
   return rewriter.notifyMatchFailure(
-      binder.op, "unimplemented: expected center_point_box "
-                 "attribute value to be 0");
+      binder.op, "expected center_point_box attribute to be 0 or 1");

// TODO: Support multiple batches and classes
// Squeeze the boxes and scores tensors.
@@ -3727,6 +3725,49 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
"failed to squeeze scores tensor");
boxes = squeezedBoxes.value();
scores = squeezedScores.value();
if (centerPointBox == 1) {
// When center_point_box is 1, the box data is supplied as
// [[x_center, y_center, width, height], ...]. Slice it to
// [[x_center, y_center], ...] and [[width, height], ...],
// calculate the [[x1, y1], ...] and [[x2, y2], ...], and concatenate
// to [[x1, y1, x2, y2], ...].
auto boxesTensorType =
dyn_cast<Torch::ValueTensorType>(boxes.getType());
Value const0 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value const1 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value const2 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(2));
Value const4 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(4));
Value const2F = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(2.0));

// Slice the boxes into [x_center, y_center] and [width, height] halves.
auto sliceShape = SmallVector<int64_t>{Torch::kUnknownSize, 2};
auto sliceTensorType = rewriter.getType<Torch::ValueTensorType>(
sliceShape, boxesTensorType.getDtype());
Value centers = rewriter.create<Torch::AtenSliceTensorOp>(
loc, sliceTensorType, boxes, const1, const0, const2, const1);
Value sizes = rewriter.create<Torch::AtenSliceTensorOp>(
loc, sliceTensorType, boxes, const1, const2, const4, const1);
Value halfSizes = rewriter.create<Torch::AtenDivScalarOp>(
loc, sizes.getType(), sizes, const2F);
Value x1y1s = rewriter.create<Torch::AtenSubTensorOp>(
loc, centers.getType(), centers, halfSizes, const1);
Value x2y2s = rewriter.create<Torch::AtenAddTensorOp>(
loc, centers.getType(), centers, halfSizes, const1);

Type listElemType = boxesTensorType.getWithSizesAndDtype(
/*optionalSizes=*/std::nullopt,
/*optionalDtype=*/nullptr);
Type listType = Torch::ListType::get(listElemType);
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
loc, listType, SmallVector<Value>{x1y1s, x2y2s});
boxes = rewriter.create<Torch::AtenCatOp>(loc, boxesTensorType,
tensorList, const1);
}

// TODO: Support score_threshold input
// Filter out the boxes if the score < score_threshold
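To make the lowering above concrete: the two slice ops split each row into [x_center, y_center] and [width, height], the div halves the sizes, and the sub/add/cat sequence rebuilds corner-format rows. A plain-C++ rendering of the same pipeline over a row-major [N, 4] buffer (a hypothetical sketch, not the MLIR code itself):

#include <cstddef>
#include <vector>

// boxes: row-major [N, 4] rows of [x_center, y_center, width, height].
// Returns row-major [N, 4] rows of [x1, y1, x2, y2], mirroring the
// slice(0:2) / slice(2:4) / div(2.0) / sub / add / cat(dim=1) sequence.
std::vector<float> centerBoxesToCorners(const std::vector<float> &boxes) {
  const std::size_t n = boxes.size() / 4;
  std::vector<float> out(boxes.size());
  for (std::size_t i = 0; i < n; ++i) {
    const float *b = &boxes[i * 4];
    const float halfW = b[2] / 2.0f; // halfSizes = sizes / 2.0
    const float halfH = b[3] / 2.0f;
    out[i * 4 + 0] = b[0] - halfW;   // x1y1s = centers - halfSizes
    out[i * 4 + 1] = b[1] - halfH;
    out[i * 4 + 2] = b[0] + halfW;   // x2y2s = centers + halfSizes
    out[i * 4 + 3] = b[1] + halfH;
  }
  return out; // each row is cat([x1y1s, x2y2s], dim=1)
}
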
69 changes: 69 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -2145,6 +2145,75 @@ func.func @test_nonmaxsuppression_single_box(%arg0: !torch.vtensor<[1,1,4],f32>,
return %0 : !torch.vtensor<[1,3],si64>
}

// CHECK-LABEL: func.func @test_nonmaxsuppression_center_point_box(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,4],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,1,1],f32>,
// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[1],si64>,
// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[1],f32>,
// CHECK-SAME: %[[VAL_4:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
func.func @test_nonmaxsuppression_center_point_box(%arg0: !torch.vtensor<[1,1,4],f32>, %arg1: !torch.vtensor<[1,1,1],f32>, %arg2: !torch.vtensor<[1],si64>, %arg3: !torch.vtensor<[1],f32>, %arg4: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[VAL_5:.*]] = torch.constant.int 0
// CHECK: %[[VAL_6:.*]] = torch.constant.int 1
// CHECK: %[[VAL_7:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,1,4],f32>, !torch.int -> !torch.int
// CHECK: %[[VAL_8:.*]] = torch.aten.eq.int %[[VAL_7]], %[[VAL_6]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_8]], "squeeze operation possible for dim only when input_shape[dim] == 1."
// CHECK: %[[VAL_9:.*]] = torch.aten.squeeze.dim %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,1,4],f32>, !torch.int -> !torch.vtensor<[1,4],f32>
// CHECK: %[[VAL_10:.*]] = torch.constant.int 0
// CHECK: %[[VAL_11:.*]] = torch.constant.int 1
// CHECK: %[[VAL_12:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,1],f32>, !torch.int -> !torch.int
// CHECK: %[[VAL_13:.*]] = torch.aten.eq.int %[[VAL_12]], %[[VAL_11]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_13]], "squeeze operation possible for dim only when input_shape[dim] == 1."
// CHECK: %[[VAL_14:.*]] = torch.aten.squeeze.dim %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32>
// CHECK: %[[VAL_15:.*]] = torch.constant.int 0
// CHECK: %[[VAL_16:.*]] = torch.constant.int 1
// CHECK: %[[VAL_17:.*]] = torch.aten.size.int %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.int
// CHECK: %[[VAL_18:.*]] = torch.aten.eq.int %[[VAL_17]], %[[VAL_16]] : !torch.int, !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_18]], "squeeze operation possible for dim only when input_shape[dim] == 1."
// CHECK: %[[VAL_19:.*]] = torch.aten.squeeze.dim %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1],f32>
// CHECK: %[[VAL_20:.*]] = torch.constant.int 0
// CHECK: %[[VAL_21:.*]] = torch.constant.int 1
// CHECK: %[[VAL_22:.*]] = torch.constant.int 2
// CHECK: %[[VAL_23:.*]] = torch.constant.int 4
// CHECK: %[[VAL_24:.*]] = torch.constant.float 2.000000e+00
// CHECK: %[[VAL_25:.*]] = torch.aten.slice.Tensor %[[VAL_9]], %[[VAL_21]], %[[VAL_20]], %[[VAL_22]], %[[VAL_21]] : !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,2],f32>
// CHECK: %[[VAL_26:.*]] = torch.aten.slice.Tensor %[[VAL_9]], %[[VAL_21]], %[[VAL_22]], %[[VAL_23]], %[[VAL_21]] : !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,2],f32>
// CHECK: %[[VAL_27:.*]] = torch.aten.div.Scalar %[[VAL_26]], %[[VAL_24]] : !torch.vtensor<[?,2],f32>, !torch.float -> !torch.vtensor<[?,2],f32>
// CHECK: %[[VAL_28:.*]] = torch.aten.sub.Tensor %[[VAL_25]], %[[VAL_27]], %[[VAL_21]] : !torch.vtensor<[?,2],f32>, !torch.vtensor<[?,2],f32>, !torch.int -> !torch.vtensor<[?,2],f32>
// CHECK: %[[VAL_29:.*]] = torch.aten.add.Tensor %[[VAL_25]], %[[VAL_27]], %[[VAL_21]] : !torch.vtensor<[?,2],f32>, !torch.vtensor<[?,2],f32>, !torch.int -> !torch.vtensor<[?,2],f32>
// CHECK: %[[VAL_30:.*]] = torch.prim.ListConstruct %[[VAL_28]], %[[VAL_29]] : (!torch.vtensor<[?,2],f32>, !torch.vtensor<[?,2],f32>) -> !torch.list<vtensor>
// CHECK: %[[VAL_31:.*]] = torch.aten.cat %[[VAL_30]], %[[VAL_21]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,4],f32>
// CHECK: %[[VAL_32:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[1],f32> -> !torch.float
// CHECK: %[[VAL_33:.*]] = torch.aten.min %[[VAL_19]] : !torch.vtensor<[1],f32> -> !torch.vtensor<[],f32>
// CHECK: %[[VAL_34:.*]] = torch.aten.item %[[VAL_33]] : !torch.vtensor<[],f32> -> !torch.float
// CHECK: %[[VAL_35:.*]] = torch.aten.ge.float %[[VAL_34]], %[[VAL_32]] : !torch.float, !torch.float -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_35]], "unimplemented: score_threshold should be <= min(scores)"
// CHECK: %[[VAL_36:.*]] = torch.constant.int 0
// CHECK: %[[VAL_37:.*]] = torch.constant.int 1
// CHECK: %[[VAL_38:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[VAL_39:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[1],f32> -> !torch.float
// CHECK: %[[VAL_40:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[VAL_41:.*]] = torch.torchvision.nms %[[VAL_31]], %[[VAL_19]], %[[VAL_39]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[?],si64>
// CHECK: %[[VAL_42:.*]] = torch.aten.size.int %[[VAL_41]], %[[VAL_36]] : !torch.vtensor<[?],si64>, !torch.int -> !torch.int
// CHECK: %[[VAL_43:.*]] = torch.aten.gt.int %[[VAL_42]], %[[VAL_40]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[VAL_44:.*]] = torch.prim.If %[[VAL_43]] -> (!torch.vtensor<[1],si64>) {
// CHECK: %[[VAL_45:.*]] = torch.aten.slice.Tensor %[[VAL_41]], %[[VAL_36]], %[[VAL_36]], %[[VAL_40]], %[[VAL_37]] : !torch.vtensor<[?],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: torch.prim.If.yield %[[VAL_45]] : !torch.vtensor<[1],si64>
// CHECK: } else {
// CHECK: %[[VAL_46:.*]] = torch.tensor_static_info_cast %[[VAL_41]] : !torch.vtensor<[?],si64> to !torch.vtensor<[1],si64>
// CHECK: torch.prim.If.yield %[[VAL_46]] : !torch.vtensor<[1],si64>
// CHECK: }
// CHECK: %[[VAL_47:.*]] = torch.aten.unsqueeze %[[VAL_44]], %[[VAL_37]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
// CHECK: %[[VAL_48:.*]] = torch.aten.size.int %[[VAL_47]], %[[VAL_36]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int
// CHECK: %[[VAL_49:.*]] = torch.constant.int 2
// CHECK: %[[VAL_50:.*]] = torch.prim.ListConstruct %[[VAL_48]], %[[VAL_49]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_51:.*]] = torch.constant.none
// CHECK: %[[VAL_52:.*]] = torch.aten.zeros %[[VAL_50]], %[[VAL_51]], %[[VAL_51]], %[[VAL_51]], %[[VAL_51]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64>
// CHECK: %[[VAL_53:.*]] = torch.prim.ListConstruct %[[VAL_52]], %[[VAL_47]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list<vtensor>
// CHECK: %[[VAL_54:.*]] = torch.aten.cat %[[VAL_53]], %[[VAL_37]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,3],si64>
// CHECK: return %[[VAL_54]] : !torch.vtensor<[1,3],si64>
%0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) {torch.onnx.center_point_box = 1 : si64} : (!torch.vtensor<[1,1,4],f32>, !torch.vtensor<[1,1,1],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64>
return %0 : !torch.vtensor<[1,3],si64>
}
// -----

// CHECK-LABEL: func.func @test_mwm
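One detail worth noting in the expected IR above: the trailing aten.zeros / aten.cat pair packs the kept box indices into ONNX's selected_indices layout, where each row is [batch_index, class_index, box_index]; since the lowering currently supports a single batch and class, the first two columns are constant zero. A hypothetical sketch of that packing:

#include <cstdint>
#include <vector>

// Expand the box indices kept by NMS into ONNX selected_indices rows of
// [batch_index, class_index, box_index]. With a single batch and a single
// class (the only case this lowering handles), both leading columns are 0.
std::vector<int64_t> packSelectedIndices(const std::vector<int64_t> &kept) {
  std::vector<int64_t> rows;
  rows.reserve(kept.size() * 3);
  for (int64_t boxIndex : kept) {
    rows.push_back(0);        // batch_index
    rows.push_back(0);        // class_index
    rows.push_back(boxIndex); // box_index from torchvision.nms
  }
  return rows;
}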
