Skip to content

Commit

Permalink
Code cleanup and consistency improvements to prepare PR#84
Browse files Browse the repository at this point in the history
  • Loading branch information
victorreijgwart committed Nov 21, 2024
1 parent 642680f commit 9269a3b
Show file tree
Hide file tree
Showing 15 changed files with 48 additions and 75 deletions.
4 changes: 3 additions & 1 deletion library/cpp/include/wavemap/core/data_structure/aabb.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <algorithm>
#include <limits>
#include <string>
#include <utility>

#include "wavemap/core/common.h"
#include "wavemap/core/utils/math/int_math.h"
Expand All @@ -25,7 +26,8 @@ struct AABB {
PointType max = PointType::Constant(kInitialMax);

AABB() = default;
AABB(PointT min, PointT max) : min(min), max(max) {}
AABB(const PointT& min, const PointT& max) : min(min), max(max) {}
AABB(PointT&& min, PointT&& max) : min(std::move(min)), max(std::move(max)) {}

void includePoint(const PointType& point) {
min = min.cwiseMin(point);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ class ChunkedNdtreeNodePtr {
const NodeRef* operator->() const { return node_.operator->(); }

private:
// TODO(victorr): Benchmark version that uses chunk_=nullptr instead of
// std::optional to mark invalid states
std::optional<NodeRef> node_;
};

Expand All @@ -52,6 +50,7 @@ class ChunkedNdtreeNodeRef {
static constexpr int kNumChildren = NdtreeIndex<kDim>::kNumChildren;
using NodeDataType = typename ChunkType::DataType;

ChunkedNdtreeNodeRef() = delete;
ChunkedNdtreeNodeRef(ChunkType& chunk, IndexElement relative_node_depth,
LinearIndex level_traversal_distance);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,11 @@ ChunkedNdtree<NodeDataT, dim, chunk_height>::getNode(
const ChunkedNdtree::IndexType& index) {
NodePtrType node = &getRootNode();
const MortonIndex morton_code = convert::nodeIndexToMorton(index);
for (int node_height = max_height_; index.height < node_height;
for (int node_height = max_height_; node && index.height < node_height;
--node_height) {
const NdtreeIndexRelativeChild child_index =
NdtreeIndex<dim>::computeRelativeChildIndex(morton_code, node_height);
// Check if the child is allocated
NodePtrType child = node->getChild(child_index);
if (!child) {
return {};
}
node = child;
node = node->getChild(child_index);
}
return node;
}
Expand All @@ -101,16 +96,11 @@ ChunkedNdtree<NodeDataT, dim, chunk_height>::getNode(
const ChunkedNdtree::IndexType& index) const {
NodeConstPtrType node = &getRootNode();
const MortonIndex morton_code = convert::nodeIndexToMorton(index);
for (int node_height = max_height_; index.height < node_height;
for (int node_height = max_height_; node && index.height < node_height;
--node_height) {
const NdtreeIndexRelativeChild child_index =
NdtreeIndex<dim>::computeRelativeChildIndex(morton_code, node_height);
// Check if the child is allocated
NodeConstPtrType child = node->getChild(child_index);
if (!child) {
return {};
}
node = child;
node = node->getChild(child_index);
}
return node;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,11 @@ const typename Ndtree<NodeDataT, dim>::NodeType*
Ndtree<NodeDataT, dim>::getNode(const IndexType& index) const {
const NodeType* node = &root_node_;
const MortonIndex morton_code = convert::nodeIndexToMorton(index);
for (int node_height = max_height_; index.height < node_height;
for (int node_height = max_height_; node && index.height < node_height;
--node_height) {
const NdtreeIndexRelativeChild child_index =
NdtreeIndex<dim>::computeRelativeChildIndex(morton_code, node_height);
// Check if the child is allocated
const NodeType* child = node->getChild(child_index);
if (!child) {
return nullptr;
}
node = child;
node = node->getChild(child_index);
}
return node;
}
Expand Down
5 changes: 4 additions & 1 deletion library/cpp/include/wavemap/core/indexing/ndtree_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <limits>
#include <string>
#include <utility>
#include <vector>

#include "wavemap/core/common.h"
Expand Down Expand Up @@ -33,8 +34,10 @@ struct NdtreeIndex {
Position position = Position::Zero();

NdtreeIndex() = default;
NdtreeIndex(Element height, Position position)
NdtreeIndex(Element height, const Position& position)
: height(height), position(position) {}
NdtreeIndex(Element height, Position&& position)
: height(height), position(std::move(position)) {}

bool operator==(const NdtreeIndex& other) const {
return height == other.height && position == other.position;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class HashedChunkedWaveletIntegrator : public ProjectiveIntegrator {
using OctreeType = HashedChunkedWaveletOctreeBlock::OctreeType;

const HashedChunkedWaveletOctree::Ptr occupancy_map_;
std::shared_ptr<ThreadPool> thread_pool_;
const std::shared_ptr<ThreadPool> thread_pool_;
std::shared_ptr<RangeImageIntersector> range_image_intersector_;

// Cache/pre-compute commonly used values
Expand Down Expand Up @@ -68,7 +68,6 @@ class HashedChunkedWaveletIntegrator : public ProjectiveIntegrator {
void updateNodeRecursive(OctreeType::NodeRefType node,
const OctreeIndex& node_index,
FloatingPoint& node_value,
OctreeType::ChunkType::BitRef node_has_child,
bool& block_needs_thresholding);
void updateLeavesBatch(const OctreeIndex& parent_index,
FloatingPoint& parent_value,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,11 @@ class HashedWaveletIntegrator : public ProjectiveIntegrator {
: std::make_shared<ThreadPool>()) {}

private:
using BlockList = std::vector<HashedWaveletOctree::BlockIndex>;
using OctreeType = HashedWaveletOctreeBlock::OctreeType;

const HashedWaveletOctree::Ptr occupancy_map_;
std::shared_ptr<ThreadPool> thread_pool_;
const std::shared_ptr<ThreadPool> thread_pool_;
std::shared_ptr<RangeImageIntersector> range_image_intersector_;

// Cache/pre-compute commonly used values
Expand All @@ -51,13 +54,12 @@ class HashedWaveletIntegrator : public ProjectiveIntegrator {
std::pair<OctreeIndex, OctreeIndex> getFovMinMaxIndices(
const Point3D& sensor_origin) const;

using BlockList = std::vector<OctreeIndex>;
void recursiveTester(const OctreeIndex& node_index,
BlockList& update_job_list);

void updateMap() override;
void updateBlock(HashedWaveletOctree::Block& block,
const OctreeIndex& block_index);
const HashedWaveletOctree::BlockIndex& block_index);
};
} // namespace wavemap

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ inline void HashedWaveletIntegrator::recursiveTester( // NOLINT
if (node_index.height == tree_height_) {
// Get the block
if (update_type == UpdateType::kPossiblyOccupied) {
update_job_list.emplace_back(node_index);
update_job_list.emplace_back(node_index.position);
return;
}
if (const auto* block = occupancy_map_->getBlock(node_index.position);
block) {
if (min_log_odds_ + kNoiseThreshold / 10.f <= block->getRootScale()) {
// Add the block to the job list
update_job_list.emplace_back(node_index);
update_job_list.emplace_back(node_index.position);
}
}
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class HashedChunkedWaveletOctreeBlock {
void recursiveThreshold(OctreeType::NodeRefType node,
Coefficients::Scale& node_scale_coefficient);
void recursivePrune(
HashedChunkedWaveletOctreeBlock::OctreeType::NodeRefType chunk);
HashedChunkedWaveletOctreeBlock::OctreeType::NodeRefType node);
};
} // namespace wavemap

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,12 @@ inline FloatingPoint HashedChunkedWaveletOctreeBlock::getCellValue(
const MortonIndex morton_code = convert::nodeIndexToMorton(index);
OctreeType::NodeConstPtrType node = &ndtree_.getRootNode();
FloatingPoint value = root_scale_coefficient_;
for (int parent_height = tree_height_; index.height < parent_height;
for (int parent_height = tree_height_; node && index.height < parent_height;
--parent_height) {
const NdtreeIndexRelativeChild child_index =
OctreeIndex::computeRelativeChildIndex(morton_code, parent_height);
value = Transform::backwardSingleChild({value, node->data()}, child_index);
auto child = node->getChild(child_index);
if (!child) {
break;
}
node = child;
node = node->getChild(child_index);
}
return value;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,11 @@ inline FloatingPoint HashedWaveletOctreeBlock::getCellValue(
const MortonIndex morton_code = convert::nodeIndexToMorton(index);
OctreeType::NodeConstPtrType node = &ndtree_.getRootNode();
FloatingPoint value = root_scale_coefficient_;
for (int parent_height = tree_height_; index.height < parent_height;
for (int parent_height = tree_height_; node && index.height < parent_height;
--parent_height) {
const NdtreeIndexRelativeChild child_index =
OctreeIndex::computeRelativeChildIndex(morton_code, parent_height);
value = Transform::backwardSingleChild({value, node->data()}, child_index);
if (!node->hasChild(child_index)) {
break;
}
node = node->getChild(child_index);
}
return value;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ struct StampedPose {
Transformation3D pose{};

StampedPose() = default;
StampedPose(TimeAbsolute stamp, Transformation3D pose)
StampedPose(TimeAbsolute stamp, const Transformation3D& pose)
: stamp(stamp), pose(pose) {}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,13 @@ void HashedChunkedWaveletIntegrator::updateBlock(
bool block_needs_thresholding = block.getNeedsThresholding();
const OctreeIndex root_node_index{tree_height_, block_index};
updateNodeRecursive(block.getRootNode(), root_node_index,
block.getRootScale(),
block.getRootChunk().nodeHasAtLeastOneChild(0u),
block_needs_thresholding);
block.getRootScale(), block_needs_thresholding);
block.setNeedsThresholding(block_needs_thresholding);
}

void HashedChunkedWaveletIntegrator::updateNodeRecursive( // NOLINT
HashedChunkedWaveletIntegrator::OctreeType::NodeRefType node,
const OctreeIndex& node_index, FloatingPoint& node_value,
HashedChunkedWaveletIntegrator::OctreeType::ChunkType::BitRef
node_has_child,
bool& block_needs_thresholding) {
// Decompress child values
auto& node_details = node.data();
Expand Down Expand Up @@ -146,21 +142,16 @@ void HashedChunkedWaveletIntegrator::updateNodeRecursive( // NOLINT
// Since the approximation error would still be too big, refine
auto child_node = node.getOrAllocateChild(relative_child_idx);
auto& child_details = child_node.data();
auto child_has_child = child_node.hasAtLeastOneChild();

// If we're at the leaf level, directly compute the update
if (child_index.height <= termination_height_ + 1) {
updateLeavesBatch(child_index, child_value, child_details);
} else {
// Otherwise, recurse
DCHECK_GE(child_index.height, 0);
updateNodeRecursive(child_node, child_index, child_value, child_has_child,
updateNodeRecursive(child_node, child_index, child_value,
block_needs_thresholding);
}

if (child_has_child || data::is_nonzero(child_details)) {
node_has_child = true;
}
}

// Compress
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ void HashedWaveletIntegrator::updateMap() {

// Make sure the to-be-updated blocks are allocated
for (const auto& block_index : blocks_to_update) {
occupancy_map_->getOrAllocateBlock(block_index.position);
occupancy_map_->getOrAllocateBlock(block_index);
}

// Update it with the threadpool
for (const auto& block_index : blocks_to_update) {
thread_pool_->add_task([this, block_index]() {
if (auto* block = occupancy_map_->getBlock(block_index.position)) {
if (auto* block = occupancy_map_->getBlock(block_index); block) {
updateBlock(*block, block_index);
}
});
Expand All @@ -50,10 +50,9 @@ void HashedWaveletIntegrator::updateMap() {
std::pair<OctreeIndex, OctreeIndex>
HashedWaveletIntegrator::getFovMinMaxIndices(
const Point3D& sensor_origin) const {
const IndexElement height =
1 + std::max(static_cast<IndexElement>(std::ceil(
std::log2(config_.max_range / min_cell_width_))),
tree_height_);
const int height = 1 + std::max(static_cast<int>(std::ceil(std::log2(
config_.max_range / min_cell_width_))),
tree_height_);
const OctreeIndex fov_min_idx = convert::indexAndHeightToNodeIndex<3>(
convert::pointToFloorIndex<3>(
sensor_origin - Vector3D::Constant(config_.max_range),
Expand All @@ -69,25 +68,28 @@ HashedWaveletIntegrator::getFovMinMaxIndices(
return {fov_min_idx, fov_max_idx};
}

void HashedWaveletIntegrator::updateBlock(HashedWaveletOctree::Block& block,
const OctreeIndex& block_index) {
void HashedWaveletIntegrator::updateBlock(
HashedWaveletOctree::Block& block,
const HashedWaveletOctree::BlockIndex& block_index) {
ProfilerZoneScoped;
HashedWaveletOctreeBlock::OctreeType::NodeRefType root_node =
block.getRootNode();
HashedWaveletOctreeBlock::Coefficients::Scale& root_node_scale =
block.getRootScale();
block.setNeedsPruning();
block.setLastUpdatedStamp();

struct StackElement {
HashedWaveletOctreeBlock::OctreeType::NodeRefType parent_node;
OctreeType::NodeRefType parent_node;
const OctreeIndex parent_node_index;
NdtreeIndexRelativeChild next_child_idx;
HashedWaveletOctreeBlock::Coefficients::CoefficientsArray
child_scale_coefficients;
};
std::stack<StackElement> stack;
stack.emplace(StackElement{root_node, block_index, 0,

OctreeType::NodeRefType root_node = block.getRootNode();
HashedWaveletOctreeBlock::Coefficients::Scale& root_node_scale =
block.getRootScale();
stack.emplace(StackElement{root_node,
{tree_height_, block_index},
0,
HashedWaveletOctreeBlock::Transform::backward(
{root_node_scale, root_node.data()})});

Expand Down Expand Up @@ -117,8 +119,7 @@ void HashedWaveletIntegrator::updateBlock(HashedWaveletOctree::Block& block,
DCHECK_GE(current_child_idx, 0);
DCHECK_LT(current_child_idx, OctreeIndex::kNumChildren);

HashedWaveletOctreeBlock::OctreeType::NodeRefType parent_node =
stack.top().parent_node;
OctreeType::NodeRefType parent_node = stack.top().parent_node;
FloatingPoint& node_value =
stack.top().child_scale_coefficients[current_child_idx];
const OctreeIndex node_index =
Expand Down Expand Up @@ -171,7 +172,7 @@ void HashedWaveletIntegrator::updateBlock(HashedWaveletOctree::Block& block,
projection_model_->cartesianToSensorZ(C_node_center);
const FloatingPoint bounding_sphere_radius =
kUnitCubeHalfDiagonal * node_width;
HashedWaveletOctreeBlock::OctreeType::NodePtrType node =
OctreeType::NodePtrType node =
parent_node.getChild(node_index.computeRelativeChildIndex());
if (measurement_model_->computeWorstCaseApproximationError(
update_type, d_C_cell, bounding_sphere_radius) <
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,6 @@ void HashedChunkedWaveletOctreeBlock::recursivePrune( // NOLINT
}
}
}
if (!has_at_least_one_child) {
node.hasAtLeastOneChild() = false;
}
node.hasAtLeastOneChild() = has_at_least_one_child;
}
} // namespace wavemap

0 comments on commit 9269a3b

Please sign in to comment.