From fea5914d14fccea42c41e2f78de881d47a83611b Mon Sep 17 00:00:00 2001 From: Victor Reijgwart Date: Thu, 24 Oct 2024 19:10:44 +0200 Subject: [PATCH] Draft support for multi-threading --- library/python/src/raycast.cc | 54 ++++++++++++++++++++++++----------- 1 file changed, 38 insertions(+), 16 deletions(-) diff --git a/library/python/src/raycast.cc b/library/python/src/raycast.cc index f429f1c69..e5539d85e 100644 --- a/library/python/src/raycast.cc +++ b/library/python/src/raycast.cc @@ -7,14 +7,16 @@ #include #include #include +#include #include "wavemap/core/utils/iterate/ray_iterator.h" using namespace nb::literals; // NOLINT namespace wavemap { -FloatingPoint raycast(const HashedWaveletOctree& map, Point3D start_point, - Point3D end_point, FloatingPoint threshold) { +FloatingPoint raycast(const HashedWaveletOctree& map, + const Point3D& start_point, const Point3D& end_point, + FloatingPoint threshold) { const FloatingPoint mcw = map.getMinCellWidth(); const Ray ray(start_point, end_point, mcw); for (const Index3D& ray_voxel_index : ray) { @@ -29,8 +31,8 @@ FloatingPoint raycast(const HashedWaveletOctree& map, Point3D start_point, FloatingPoint raycast_fast( QueryAccelerator& query_accelerator, - Point3D start_point, Point3D end_point, FloatingPoint threshold, - FloatingPoint min_cell_width) { + const Point3D& start_point, const Point3D& end_point, + FloatingPoint threshold, FloatingPoint min_cell_width) { const Ray ray(start_point, end_point, min_cell_width); for (const Index3D& ray_voxel_index : ray) { if (query_accelerator.getCellValue(ray_voxel_index) > threshold) { @@ -72,25 +74,45 @@ void add_raycast_bindings(nb::module_& m) { m.def( "get_depth", - [](const HashedWaveletOctree& map, Transformation3D pose, - PinholeCameraProjectorConfig cam_cfg, FloatingPoint threshold, + [](const HashedWaveletOctree& map, const Transformation3D& pose, + const PinholeCameraProjectorConfig& cam_cfg, FloatingPoint threshold, FloatingPoint max_range) { + // NOTE: This way of parallelizing it is not very efficient, as it + // creates a very large number of jobs (1 per pixel), but already + // leads to a nice speedup. The next step to improve it would + // probably be to split the image into tiles and spawn 1 job per + // tile. The tile size should be such that there are enough tiles + // to distribute the work evenly across all cores even if some + // tiles take much shorter than others, while still being few + // enough to minimize the overhead of dispatching jobs and create + // local QueryAccelerator instances for each job. Maybe 10x as + // many tiles as there are workers? + ThreadPool thread_pool; // By default, the pool will spawn as many + // workers as the system's reported + // std::thread::hardware_concurrency(). Image depth_image(cam_cfg.width, cam_cfg.height); - QueryAccelerator query_accelerator(map); - const FloatingPoint mcw = map.getMinCellWidth(); + const FloatingPoint min_cell_width = map.getMinCellWidth(); const PinholeCameraProjector projection_model(cam_cfg); - auto start_point = pose.getPosition(); + const Point3D& start_point = pose.getPosition(); for (const Index2D& index : Grid<2>(Index2D::Zero(), depth_image.getDimensions() - Index2D::Ones())) { - const Vector2D image_xy = projection_model.indexToImage(index); - const Point3D C_point = - projection_model.sensorToCartesian({image_xy, max_range}); - const Point3D end_point = pose * C_point; - FloatingPoint depth = raycast_fast(query_accelerator, start_point, - end_point, threshold, mcw); - depth_image.at(index) = depth; + FloatingPoint& depth_pixel = depth_image.at(index); + thread_pool.add_task([&map, &projection_model, &pose, &start_point, + &depth_pixel, index, max_range, threshold, + min_cell_width]() { + QueryAccelerator query_accelerator(map); + const Vector2D image_xy = projection_model.indexToImage(index); + const Point3D C_point = + projection_model.sensorToCartesian({image_xy, max_range}); + const Point3D end_point = pose * C_point; + FloatingPoint depth = + raycast_fast(query_accelerator, start_point, end_point, + threshold, min_cell_width); + depth_pixel = depth; + }); } + thread_pool.wait_all(); return depth_image.getData().transpose().eval(); }, "Extract depth from octree map at using given camera pose and "