diff --git a/plugin/efficientNMSPlugin/EfficientNMSPlugin_PluginConfig.yaml b/plugin/efficientNMSPlugin/EfficientNMSPlugin_PluginConfig.yaml index 074308b6..3ce2e849 100644 --- a/plugin/efficientNMSPlugin/EfficientNMSPlugin_PluginConfig.yaml +++ b/plugin/efficientNMSPlugin/EfficientNMSPlugin_PluginConfig.yaml @@ -9,6 +9,7 @@ versions: - max_output_boxes - background_class - score_activation + - class_agnostic - box_coding attribute_types: score_threshold: float32 @@ -16,6 +17,7 @@ versions: max_output_boxes: int32 background_class: int32 score_activation: int32 + class_agnostic: int32 box_coding: int32 attribute_length: score_threshold: 1 @@ -23,6 +25,7 @@ versions: max_output_boxes: 1 background_class: 1 score_activation: 1 + class_agnostic: 1 box_coding: 1 attribute_options: score_threshold: @@ -40,6 +43,9 @@ versions: score_activation: - 0 - 1 + class_agnostic: + - 0 + - 1 box_coding: - 0 - 1 diff --git a/plugin/efficientNMSPlugin/README.md b/plugin/efficientNMSPlugin/README.md index cb8603de..8a972b36 100644 --- a/plugin/efficientNMSPlugin/README.md +++ b/plugin/efficientNMSPlugin/README.md @@ -98,6 +98,7 @@ The following four output tensors are generated: |`int` |`max_output_boxes` |The maximum number of detections to output per image. |`int` |`background_class` |The label ID for the background class. If there is no background class, set it to `-1`. |`bool` |`score_activation` * |Set to true to apply sigmoid activation to the confidence scores during NMS operation. +|`bool` |`class_agnostic` |Set to true to do class-independent NMS; otherwise, boxes of different classes would be considered separately during NMS. |`int` |`box_coding` |Coding type used for boxes (and anchors if applicable), 0 = BoxCorner, 1 = BoxCenterSize. Parameters marked with a `*` have a non-negligible effect on runtime latency. See the [Performance Tuning](#performance-tuning) section below for more details on how to set them optimally. 
@@ -134,6 +135,10 @@ The algorithm is highly sensitive to the selected `score_threshold` parameter. W Depending on network configuration, it is usually more efficient to provide raw scores (pre-sigmoid) to the NMS plugin scores input, and enable the `score_activation` parameter. Doing so applies a sigmoid activation only to the last `max_output_boxes` selected scores, instead of all the predicted scores, largely reducing the computational cost. +#### Class Independent NMS + +Some object detection architectures, such as the YOLO series, require class-agnostic (class-independent) NMS, in which overlapping boxes suppress each other regardless of their class labels. If `class_agnostic` is enabled, NMS is performed across all classes together; otherwise, NMS is performed separately for each class, so only boxes of the same class can suppress each other. + #### Using the Fused Box Decoder When using networks with many anchors, such as EfficientDet or SSD, it may be more efficient to do box decoding within the NMS plugin. For this, pass the raw box predictions as the boxes input, and the default anchor coordinates as the optional third input to the plugin. 
diff --git a/plugin/efficientNMSPlugin/efficientNMSInference.cu b/plugin/efficientNMSPlugin/efficientNMSInference.cu index 28135b8c..3cf7e8ef 100644 --- a/plugin/efficientNMSPlugin/efficientNMSInference.cu +++ b/plugin/efficientNMSPlugin/efficientNMSInference.cu @@ -314,12 +314,16 @@ __global__ void EfficientNMS(EfficientNMSParameters param, const int* topNumData for (int tile = 0; tile < numTiles; tile++) { + bool ignoreClass = true; + if (!param.classAgnostic) + ignoreClass = threadClass[tile] == testClass; + // IOU if (boxIdx[tile] > i && // Make sure two different boxes are being tested, and that it's a higher index; boxIdx[tile] < numSelectedBoxes && // Make sure the box is within numSelectedBoxes; blockState == 1 && // Signal that allows IOU checks to be performed; threadState[tile] == 0 && // Make sure this box hasn't been either dropped or kept already; - threadClass[tile] == testClass && // Compare only boxes of matching classes; + ignoreClass && // Compare only boxes of matching classes when classAgnostic is false; lte_mp(threadScore[tile], testScore) && // Make sure the sorting order of scores is as expected; IOU(param, threadBox[tile], testBox) >= param.iouThreshold) // And... IOU overlap. 
{ diff --git a/plugin/efficientNMSPlugin/efficientNMSParameters.h b/plugin/efficientNMSPlugin/efficientNMSParameters.h index 216455bb..9cc4e6a6 100644 --- a/plugin/efficientNMSPlugin/efficientNMSParameters.h +++ b/plugin/efficientNMSPlugin/efficientNMSParameters.h @@ -37,6 +37,7 @@ struct EfficientNMSParameters bool scoreSigmoid = false; bool clipBoxes = false; int boxCoding = 0; + bool classAgnostic = false; // Related to NMS Internals int numSelectedBoxes = 4096; diff --git a/plugin/efficientNMSPlugin/efficientNMSPlugin.cpp b/plugin/efficientNMSPlugin/efficientNMSPlugin.cpp index 6edbd3d6..ff663257 100644 --- a/plugin/efficientNMSPlugin/efficientNMSPlugin.cpp +++ b/plugin/efficientNMSPlugin/efficientNMSPlugin.cpp @@ -428,6 +428,7 @@ EfficientNMSPluginCreator::EfficientNMSPluginCreator() mPluginAttributes.emplace_back(PluginField("max_output_boxes", nullptr, PluginFieldType::kINT32, 1)); mPluginAttributes.emplace_back(PluginField("background_class", nullptr, PluginFieldType::kINT32, 1)); mPluginAttributes.emplace_back(PluginField("score_activation", nullptr, PluginFieldType::kINT32, 1)); + mPluginAttributes.emplace_back(PluginField("class_agnostic", nullptr, PluginFieldType::kINT32, 1)); mPluginAttributes.emplace_back(PluginField("box_coding", nullptr, PluginFieldType::kINT32, 1)); mFC.nbFields = mPluginAttributes.size(); mFC.fields = mPluginAttributes.data(); @@ -493,6 +494,10 @@ IPluginV2DynamicExt* EfficientNMSPluginCreator::createPlugin(const char* name, c PLUGIN_VALIDATE(scoreSigmoid == 0 || scoreSigmoid == 1); mParam.scoreSigmoid = static_cast(scoreSigmoid); } + if (!strcmp(attrName, "class_agnostic")) + { + mParam.classAgnostic = *(static_cast(fields[i].data)); + } if (!strcmp(attrName, "box_coding")) { PLUGIN_VALIDATE(fields[i].type == PluginFieldType::kINT32); diff --git a/samples/python/detectron2/create_onnx.py b/samples/python/detectron2/create_onnx.py index 90b29386..8e2c1e8e 100644 --- a/samples/python/detectron2/create_onnx.py +++ 
b/samples/python/detectron2/create_onnx.py @@ -289,6 +289,7 @@ def NMS(self, boxes, scores, anchors, background_class, score_activation, max_pr 'score_threshold': max(0.01, score_threshold), 'iou_threshold': iou_threshold, 'score_activation': score_activation, + 'class_agnostic': False, 'box_coding': 1, } ) diff --git a/samples/python/efficientdet/create_onnx.py b/samples/python/efficientdet/create_onnx.py index 01897192..0c66620c 100644 --- a/samples/python/efficientdet/create_onnx.py +++ b/samples/python/efficientdet/create_onnx.py @@ -386,6 +386,7 @@ def get_anchor_np(output_idx, op): 'score_threshold': max(0.01, score_threshold), # Keep threshold to at least 0.01 for better efficiency 'iou_threshold': iou_threshold, 'score_activation': True, + 'class_agnostic': False, 'box_coding': 1, } nms_output_classes_dtype = np.int32 diff --git a/samples/python/tensorflow_object_detection_api/create_onnx.py b/samples/python/tensorflow_object_detection_api/create_onnx.py index 35b7064d..3ecb1b93 100644 --- a/samples/python/tensorflow_object_detection_api/create_onnx.py +++ b/samples/python/tensorflow_object_detection_api/create_onnx.py @@ -367,6 +367,7 @@ def NMS(self, box_net_tensor, class_net_tensor, anchors_tensor, background_class 'score_threshold': max(0.01, score_threshold), 'iou_threshold': iou_threshold, 'score_activation': score_activation, + 'class_agnostic': False, 'box_coding': 1, } )