Skip to content

Commit

Permalink
support yolox-pose sdk
Browse files Browse the repository at this point in the history
  • Loading branch information
RunningLeon committed Aug 30, 2023
1 parent b7a69bc commit d820513
Show file tree
Hide file tree
Showing 17 changed files with 307 additions and 59 deletions.
9 changes: 9 additions & 0 deletions configs/mmpose/pose-detection_yolox-pose_sdk_dynamic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
_base_ = ['./pose-detection_static.py', '../_base_/backends/sdk.py']

codebase_config = dict(model_type='sdk_yoloxpose')

backend_config = dict(pipeline=[
dict(type='LoadImageFromFile'),
dict(type='PoseToDetConverter'),
dict(type='PackDetPoseInputs')
])
14 changes: 14 additions & 0 deletions csrc/mmdeploy/apis/c/mmdeploy/pose_detector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ void mmdeploy_pose_detector_release_result(mmdeploy_pose_detection_t* results, i
for (int i = 0; i < count; ++i) {
delete[] results[i].point;
delete[] results[i].score;
delete[] results[i].bboxes;
delete[] results[i].bbox_score;
}
delete[] results;
}
Expand Down Expand Up @@ -156,16 +158,28 @@ int mmdeploy_pose_detector_get_result(mmdeploy_value_t output,
for (const auto& bbox_result : detections) {
auto& res = _results[result_idx++];
auto size = bbox_result.key_points.size();
auto num_bbox = bbox_result.detections.size();

res.point = new mmdeploy_point_t[size];
res.score = new float[size];
res.length = static_cast<int>(size);
res.bboxes = new mmdeploy_rect_t[num_bbox];
res.bbox_score = new float[num_bbox];
res.num_bbox = static_cast<int>(num_bbox);

for (int k = 0; k < size; k++) {
res.point[k].x = bbox_result.key_points[k].bbox[0];
res.point[k].y = bbox_result.key_points[k].bbox[1];
res.score[k] = bbox_result.key_points[k].score;
}
for (int k = 0; k < num_bbox; k++) {
res.bboxes[k].left = bbox_result.detections[k].boundingbox[0];
res.bboxes[k].top = bbox_result.detections[k].boundingbox[1];
res.bboxes[k].right = bbox_result.detections[k].boundingbox[2];
res.bboxes[k].bottom = bbox_result.detections[k].boundingbox[3];
res.bbox_score[k] = bbox_result.detections[k].score;
}

}

*results = _results.release();
Expand Down
4 changes: 4 additions & 0 deletions csrc/mmdeploy/apis/c/mmdeploy/pose_detector.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ extern "C" {
typedef struct mmdeploy_pose_detection_t {
mmdeploy_point_t* point; ///< keypoint
float* score; ///< keypoint score
mmdeploy_rect_t* bboxes; ///< bboxes
float* bbox_score; ///< bboxes score
int length; ///< number of keypoint
int num_bbox; ///< number of bboxes
} mmdeploy_pose_detection_t;

typedef struct mmdeploy_pose_detector* mmdeploy_pose_detector_t;
Expand Down Expand Up @@ -76,6 +79,7 @@ MMDEPLOY_API int mmdeploy_pose_detector_apply(mmdeploy_pose_detector_t detector,
* bboxes, must be release by \ref mmdeploy_pose_detector_release_result
* @return status code of the operation
*/

MMDEPLOY_API int mmdeploy_pose_detector_apply_bbox(mmdeploy_pose_detector_t detector,
const mmdeploy_mat_t* mats, int mat_count,
const mmdeploy_rect_t* bboxes,
Expand Down
44 changes: 31 additions & 13 deletions csrc/mmdeploy/apis/python/pose_detector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,20 +65,38 @@ class PyPoseDetector {

auto output = py::list{};
auto result = detection;

for (int i = 0; i < mats.size(); i++) {
int n_point = result->length;
auto pred = py::array_t<float>({bbox_count[i], n_point, 3});
auto dst = pred.mutable_data();
for (int j = 0; j < bbox_count[i]; j++) {
int n_point_total = result->length;
int n_bbox = result->num_bbox;
int n_point = n_bbox > 0 ? n_point_total / n_bbox : 0;
int pts_ind = 0;
auto pred_pts = py::array_t<float>({n_bbox * n_point, 3});
auto pred_bbox = py::array_t<float>({n_bbox, 5});
auto dst_pts = pred_pts.mutable_data();
auto dst_bbox = pred_bbox.mutable_data();

// printf("num_bbox %d num_pts %d\n", result->num_bbox, result->length);
for (int j = 0; j < n_bbox; j++) {
for (int k = 0; k < n_point; k++) {
dst[0] = result->point[k].x;
dst[1] = result->point[k].y;
dst[2] = result->score[k];
dst += 3;
pts_ind = j * n_point + k;
dst_pts[0] = result->point[pts_ind].x;
dst_pts[1] = result->point[pts_ind].y;
dst_pts[2] = result->score[pts_ind];
dst_pts += 3;
// printf("pts %f %f %f\n", dst_pts[0], dst_pts[1], dst_pts[2]);

}
result++;
dst_bbox[0] = result->bboxes[j].left;
dst_bbox[1] = result->bboxes[j].top;
dst_bbox[2] = result->bboxes[j].right;
dst_bbox[3] = result->bboxes[j].bottom;
dst_bbox[4] = result->bbox_score[j];
// printf("box %f %f %f %f %f\n", dst_bbox[0], dst_bbox[1], dst_bbox[2], dst_bbox[3], dst_bbox[4]);
dst_bbox += 5;
}
output.append(std::move(pred));
result++;
output.append(py::make_tuple(std::move(pred_bbox), std::move(pred_pts)));
}

int total = std::accumulate(bbox_count.begin(), bbox_count.end(), 0);
Expand All @@ -101,12 +119,12 @@ static PythonBindingRegisterer register_pose_detector{[](py::module& m) {
}),
py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0)
.def("__call__",
[](PyPoseDetector* self, const PyImage& img) -> py::array {
[](PyPoseDetector* self, const PyImage& img) -> py::tuple {
return self->Apply({img}, {})[0];
})
.def(
"__call__",
[](PyPoseDetector* self, const PyImage& img, const Rect& box) -> py::array {
[](PyPoseDetector* self, const PyImage& img, const Rect& box) -> py::tuple {
std::vector<std::vector<Rect>> bboxes;
bboxes.push_back({box});
return self->Apply({img}, bboxes)[0];
Expand All @@ -115,7 +133,7 @@ static PythonBindingRegisterer register_pose_detector{[](py::module& m) {
.def(
"__call__",
[](PyPoseDetector* self, const PyImage& img,
const std::vector<Rect>& bboxes) -> py::array {
const std::vector<Rect>& bboxes) -> py::tuple {
std::vector<std::vector<Rect>> _bboxes;
_bboxes.push_back(bboxes);
return self->Apply({img}, _bboxes)[0];
Expand Down
5 changes: 5 additions & 0 deletions csrc/mmdeploy/codebase/mmpose/keypoints_from_heatmap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ class TopdownHeatmapBaseHeadDecode : public MMPose {
}

auto& img_metas = _data["img_metas"];
if (img_metas.contains("bbox")) {
from_value(img_metas["bbox"], bbox_);
}

vector<float> center;
vector<float> scale;
Expand All @@ -78,6 +81,7 @@ class TopdownHeatmapBaseHeadDecode : public MMPose {
output.key_points.push_back({{x, y}, s});
data += 3;
}
output.detections.push_back({{bbox_[0], bbox_[1], bbox_[2], bbox_[3]}, bbox_[4]});
return to_value(std::move(output));
}

Expand Down Expand Up @@ -354,6 +358,7 @@ class TopdownHeatmapBaseHeadDecode : public MMPose {
float valid_radius_factor_{0.0546875f};
bool use_udp_{false};
string target_type_{"GaussianHeatmap"};
vector<float> bbox_{0, 0, 1, 1, 1};
};

MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMPose, TopdownHeatmapBaseHeadDecode);
Expand Down
7 changes: 6 additions & 1 deletion csrc/mmdeploy/codebase/mmpose/keypoints_from_regression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ class DeepposeRegressionHeadDecode : public MMPose {
}

auto& img_metas = _data["img_metas"];

if (img_metas.contains("bbox")) {
from_value(img_metas["bbox"], bbox_);
}
vector<float> center;
vector<float> scale;
from_value(img_metas["center"], center);
Expand All @@ -60,6 +62,7 @@ class DeepposeRegressionHeadDecode : public MMPose {
output.key_points.push_back({{x, y}, s});
data += 3;
}
output.detections.push_back({{bbox_[0], bbox_[1], bbox_[2], bbox_[3]}, bbox_[4]});
return to_value(std::move(output));
}

Expand Down Expand Up @@ -106,6 +109,8 @@ class DeepposeRegressionHeadDecode : public MMPose {
*(data + 0) = *(data + 0) * scale_x + center[0] - scale[0] * 0.5;
*(data + 1) = *(data + 1) * scale_y + center[1] - scale[1] * 0.5;
}
private:
vector<float> bbox_{0, 0, 1, 1, 1};
};

MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMPose, DeepposeRegressionHeadDecode);
Expand Down
9 changes: 8 additions & 1 deletion csrc/mmdeploy/codebase/mmpose/mmpose.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,17 @@ struct PoseDetectorOutput {
float score;
MMDEPLOY_ARCHIVE_MEMBERS(bbox, score);
};
struct BBox {
std::array<float, 4> boundingbox; // x1,y1,x2,y2
float score;
MMDEPLOY_ARCHIVE_MEMBERS(boundingbox, score);
};
std::vector<KeyPoint> key_points;
MMDEPLOY_ARCHIVE_MEMBERS(key_points);
std::vector<BBox> detections;
MMDEPLOY_ARCHIVE_MEMBERS(key_points, detections);
};


MMDEPLOY_DECLARE_CODEBASE(MMPose, mmpose);

} // namespace mmdeploy::mmpose
Expand Down
8 changes: 6 additions & 2 deletions csrc/mmdeploy/codebase/mmpose/simcc_label.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ class SimCCLabelDecode : public MMPose {
}

auto& img_metas = _data["img_metas"];

if (img_metas.contains("bbox")) {
from_value(img_metas["bbox"], bbox_);
}
Tensor keypoints({Device{"cpu"}, DataType::kFLOAT, {simcc_x.shape(0), simcc_x.shape(1), 2}});
Tensor scores({Device{"cpu"}, DataType::kFLOAT, {simcc_x.shape(0), simcc_x.shape(1), 1}});
get_simcc_maximum(simcc_x, simcc_y, keypoints, scores);
Expand All @@ -73,6 +75,7 @@ class SimCCLabelDecode : public MMPose {
keypoints_data += 2;
scores_data += 1;
}
output.detections.push_back({{bbox_[0], bbox_[1], bbox_[2], bbox_[3]}, bbox_[4]});
return to_value(output);
}

Expand Down Expand Up @@ -102,7 +105,8 @@ class SimCCLabelDecode : public MMPose {
}
}

private:
private:
vector<float> bbox_{0, 0, 1, 1, 1};
bool flip_test_{false};
bool shift_heatmap_{false};
float simcc_split_ratio_{2.0};
Expand Down
101 changes: 101 additions & 0 deletions csrc/mmdeploy/codebase/mmpose/yolox_pose.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include <cctype>
#include <opencv2/imgproc.hpp>
#include <iostream>
#include <fstream>

#include "mmdeploy/core/device.h"
#include "mmdeploy/core/registry.h"
#include "mmdeploy/core/serialization.h"
#include "mmdeploy/core/tensor.h"
#include "mmdeploy/core/utils/device_utils.h"
#include "mmdeploy/core/utils/formatter.h"
#include "mmdeploy/core/value.h"
#include "mmdeploy/experimental/module_adapter.h"
#include "mmpose.h"
#include "opencv_utils.h"

namespace mmdeploy::mmpose {

using std::string;
using std::vector;

class YOLOXPose : public MMPose {
public:
explicit YOLOXPose(const Value& config) : MMPose(config) {
if (config.contains("params")) {
auto& params = config["params"];
if (params.contains("score_thr")) {
from_value(params["score_thr"], score_thr_);
}
}
}

Result<Value> operator()(const Value& _data, const Value& _prob) {
MMDEPLOY_DEBUG("preprocess_result: {}", _data);
MMDEPLOY_DEBUG("inference_result: {}", _prob);

Device cpu_device{"cpu"};
OUTCOME_TRY(auto dets,
MakeAvailableOnDevice(_prob["dets"].get<Tensor>(), cpu_device, stream()));
OUTCOME_TRY(auto keypoints,
MakeAvailableOnDevice(_prob["keypoints"].get<Tensor>(), cpu_device, stream()));
OUTCOME_TRY(stream().Wait());
if (!(dets.shape().size() == 3 && dets.data_type() == DataType::kFLOAT)) {
MMDEPLOY_ERROR("unsupported `dets` tensor, shape: {}, dtype: {}", dets.shape(),
(int)dets.data_type());
return Status(eNotSupported);
}
if (!(keypoints.shape().size() == 4 && keypoints.data_type() == DataType::kFLOAT)) {
MMDEPLOY_ERROR("unsupported `keypoints` tensor, shape: {}, dtype: {}", keypoints.shape(),
(int)keypoints.data_type());
return Status(eNotSupported);
}
auto& img_metas = _data["img_metas"];
vector<float> scale_factor;
if (img_metas.contains("scale_factor")) {
from_value(img_metas["scale_factor"], scale_factor);
} else {
scale_factor = {1.f, 1.f, 1.f, 1.f};
}
PoseDetectorOutput output;

float* keypoints_data = keypoints.data<float>();
float* dets_data = dets.data<float>();
int num_dets = dets.shape(1), num_pts = keypoints.shape(2);
float s = 0, x1=0, y1=0, x2=0, y2=0;

// fprintf(stdout, "num_dets= %d num_pts = %d\n", num_dets, num_pts);
for (int i = 0; i < dets.shape(0) * num_dets; i++){
x1 = (*(dets_data++)) / scale_factor[0];
y1 = (*(dets_data++)) / scale_factor[1];
x2 = (*(dets_data++)) / scale_factor[2];
y2 = (*(dets_data++)) / scale_factor[3];
s = *(dets_data++);
// fprintf(stdout, "box %.2f %.2f %.2f %.2f %.6f\n", i, x1,y1,x2,y2,s);

if (s <= score_thr_) {
keypoints_data += num_pts * 3;
continue;
}
output.detections.push_back({{x1, y1, x2, y2}, s});
for (int k = 0; k < num_pts; k++) {
x1 = (*(keypoints_data++)) / scale_factor[0];
y1 = (*(keypoints_data++)) / scale_factor[1];
s = *(keypoints_data++);
// fprintf(stdout, "point %d, index %d, %.2f %.2f %.6f\n", k, x1, y1, s);
output.key_points.push_back({{x1, y1}, s});
}
}
return to_value(output);
}

protected:
float score_thr_ = 0.001;

};

MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMPose, YOLOXPose);

} // namespace mmdeploy::mmpose
4 changes: 4 additions & 0 deletions csrc/mmdeploy/device/cpu/cpu_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ Result<void> CpuPlatformImpl::Copy(Buffer src, Buffer dst, size_t size, size_t s
size_t dst_offset, Stream stream) {
auto src_ptr = src.GetNative();
auto dst_ptr = dst.GetNative();
if (size == 0) {
return success();
}

if (!src_ptr || !dst_ptr) {
return Status(eInvalidArgument);
}
Expand Down
16 changes: 9 additions & 7 deletions demo/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,25 @@ endif ()


function(add_example task folder name)
set(exec_name "${folder}.${name}")
if ((NOT task) OR (task IN_LIST MMDEPLOY_TASKS))
# Search for c/cpp sources
file(GLOB _SRCS ${folder}/${name}.c*)
add_executable(${name} ${_SRCS})
add_executable(${exec_name} ${_SRCS})
if (NOT (MSVC OR APPLE))
# Disable new dtags so that executables can run even without LD_LIBRARY_PATH set
target_link_libraries(${name} PRIVATE -Wl,--disable-new-dtags)
target_link_libraries(${exec_name} PRIVATE -Wl,--disable-new-dtags)
endif ()
if (MMDEPLOY_BUILD_SDK_MONOLITHIC)
target_link_libraries(${name} PRIVATE mmdeploy ${OpenCV_LIBS})
target_link_libraries(${exec_name} PRIVATE mmdeploy ${OpenCV_LIBS})
else ()
# Load MMDeploy modules
mmdeploy_load_static(${name} MMDeployStaticModules)
mmdeploy_load_dynamic(${name} MMDeployDynamicModules)
mmdeploy_load_static(${exec_name} MMDeployStaticModules)
mmdeploy_load_dynamic(${exec_name} MMDeployDynamicModules)
# Link to MMDeploy libraries
target_link_libraries(${name} PRIVATE MMDeployLibs ${OpenCV_LIBS})
target_link_libraries(${exec_name} PRIVATE MMDeployLibs ${OpenCV_LIBS})
endif ()
install(TARGETS ${name} RUNTIME DESTINATION bin)
install(TARGETS ${exec_name} RUNTIME DESTINATION bin)
endif ()
endfunction()

Expand All @@ -36,6 +37,7 @@ add_example(detector c batch_object_detection)
add_example(segmentor c image_segmentation)
add_example(restorer c image_restorer)
add_example(text_detector c ocr)
add_example(pose_detector c det_pose)
add_example(pose_detector c pose_detection)
add_example(rotated_detector c rotated_object_detection)
add_example(video_recognizer c video_recognition)
Expand Down
Loading

0 comments on commit d820513

Please sign in to comment.