From 8f330fb2c7a7e2c13a41336fb3c729135a0d81c8 Mon Sep 17 00:00:00 2001 From: Gleb Mazovetskiy Date: Thu, 24 Jun 2021 08:52:54 +0100 Subject: [PATCH] Make `Agent` copyable and moveable This allows implementing C++ iterator interface on top of the Agent more efficiently. C++ iterators must be copyable, and the only way to copy one previously was to repeat the query until the index. --- configure.ac | 7 +++ include/marisa/agent.h | 21 ++++++--- include/marisa/scoped-array.h | 10 +++++ include/marisa/scoped-ptr.h | 10 +++++ lib/marisa/agent.cc | 66 +++++++++++++++++++++++------ lib/marisa/grimoire/trie/state.h | 12 ++++-- lib/marisa/grimoire/vector/vector.h | 50 ++++++++++++++++++---- tests/marisa-test.cc | 39 +++++++++++++++++ 8 files changed, 185 insertions(+), 30 deletions(-) diff --git a/configure.ac b/configure.ac index 118a677..c7adec0 100644 --- a/configure.ac +++ b/configure.ac @@ -13,6 +13,9 @@ AC_PROG_INSTALL AC_CONFIG_MACRO_DIR([m4]) +# Sanitizers +AC_ARG_ENABLE([asan], AS_HELP_STRING([--enable-asan], [Enable address sanitizer])) + # Macros for SSE availability check. AC_DEFUN([MARISA_ENABLE_SSE2], [AC_EGREP_CPP([yes], [ @@ -241,6 +244,10 @@ elif test "x${enable_sse2}" != "xno"; then CXXFLAGS="$CXXFLAGS -DMARISA_USE_SSE2 -msse2" fi +AS_IF([test "x$enable_asan" = "xyes"], [ + CXXFLAGS="$CXXFLAGS -fsanitize=address" +]) + AC_CONFIG_FILES([Makefile marisa.pc include/Makefile diff --git a/include/marisa/agent.h b/include/marisa/agent.h index b549d36..117aed8 100644 --- a/include/marisa/agent.h +++ b/include/marisa/agent.h @@ -22,6 +22,14 @@ class Agent { Agent(); ~Agent(); + Agent(const Agent &other); + Agent &operator=(const Agent &other); + +#if __cplusplus >= 201103L + Agent(Agent &&other) noexcept; + Agent &operator=(Agent &&other) noexcept; +#endif + const Query &query() const { return query_; } @@ -37,6 +45,9 @@ class Agent { void set_query(const char *str); void set_query(const char *ptr, std::size_t length); void set_query(std::size_t key_id); + void set_query(const Query &query) { + query_ = query; + } const grimoire::trie::State &state() const { return *state_; @@ -65,7 +76,7 @@ class Agent { } bool has_state() const { - return state_.get() != NULL; + return state_ != NULL; } void init_state(); @@ -75,11 +86,11 @@ class Agent { private: Query query_; Key key_; - scoped_ptr state_; - // Disallows copy and assignment. - Agent(const Agent &); - Agent &operator=(const Agent &); + // Cannot be `scoped_ptr` because `State` is forward-declared. + grimoire::trie::State *state_; + + void clear_state(); }; } // namespace marisa diff --git a/include/marisa/scoped-array.h b/include/marisa/scoped-array.h index 34cefa4..34cfaeb 100644 --- a/include/marisa/scoped-array.h +++ b/include/marisa/scoped-array.h @@ -11,6 +11,16 @@ class scoped_array { scoped_array() : array_(NULL) {} explicit scoped_array(T *array) : array_(array) {} +#if __cplusplus >= 201103L + scoped_array(scoped_array &&other) noexcept : array_(other.array_) { + other.array_ = NULL; + } + scoped_array &operator=(scoped_array &&other) noexcept { + other.array_ = NULL; + return *this; + } +#endif + ~scoped_array() { delete [] array_; } diff --git a/include/marisa/scoped-ptr.h b/include/marisa/scoped-ptr.h index abf48d8..a4035f4 100644 --- a/include/marisa/scoped-ptr.h +++ b/include/marisa/scoped-ptr.h @@ -15,6 +15,16 @@ class scoped_ptr { delete ptr_; } +#if __cplusplus >= 201103L + scoped_ptr(scoped_ptr &&other) noexcept : ptr_(other.ptr_) { + other.ptr_ = NULL; + } + scoped_ptr &operator=(scoped_ptr &&other) noexcept { + other.ptr_ = NULL; + return *this; + } +#endif + void reset(T *ptr = NULL) { MARISA_DEBUG_IF((ptr != NULL) && (ptr == ptr_), MARISA_RESET_ERROR); scoped_ptr(ptr).swap(*this); diff --git a/lib/marisa/agent.cc b/lib/marisa/agent.cc index 7fa7cb1..51baa5b 100644 --- a/lib/marisa/agent.cc +++ b/lib/marisa/agent.cc @@ -5,47 +5,89 @@ namespace marisa { -Agent::Agent() : query_(), key_(), state_() {} +Agent::Agent() : query_(), key_(), state_(NULL) {} -Agent::~Agent() {} +Agent::~Agent() { + delete state_; +} + +Agent::Agent(const Agent &other) + : query_(other.query_), + key_(other.key_), + state_(other.has_state() ? new (std::nothrow) grimoire::trie::State(other.state()) : NULL) {} + +Agent &Agent::operator=(const Agent &other) { + query_ = other.query_; + key_ = other.key_; + delete state_; + if (other.has_state()) { + state_ = new (std::nothrow) grimoire::trie::State(other.state()); + } else { + state_ = NULL; + } + return *this; +} + +#if __cplusplus >= 201103L +Agent::Agent(Agent &&other) noexcept + : query_(other.query_), key_(other.key_), state_(other.state_) { + other.state_ = NULL; +} + +Agent &Agent::operator=(Agent &&other) noexcept { + query_ = other.query_; + key_ = other.key_; + delete state_; + state_ = other.state_; + other.state_ = NULL; + return *this; +} +#endif void Agent::set_query(const char *str) { MARISA_THROW_IF(str == NULL, MARISA_NULL_ERROR); - if (state_.get() != NULL) { - state_->reset(); + if (state_ != NULL) { + clear_state(); } query_.set_str(str); } void Agent::set_query(const char *ptr, std::size_t length) { MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_NULL_ERROR); - if (state_.get() != NULL) { - state_->reset(); + if (state_ != NULL) { + clear_state(); } query_.set_str(ptr, length); } void Agent::set_query(std::size_t key_id) { - if (state_.get() != NULL) { - state_->reset(); + if (state_ != NULL) { + clear_state(); } query_.set_id(key_id); } void Agent::init_state() { - MARISA_THROW_IF(state_.get() != NULL, MARISA_STATE_ERROR); - state_.reset(new (std::nothrow) grimoire::State); - MARISA_THROW_IF(state_.get() == NULL, MARISA_MEMORY_ERROR); + MARISA_THROW_IF(state_ != NULL, MARISA_STATE_ERROR); + delete state_; + state_ = new (std::nothrow) grimoire::State; + MARISA_THROW_IF(state_ == NULL, MARISA_MEMORY_ERROR); } void Agent::clear() { Agent().swap(*this); } + +void Agent::clear_state() { + delete state_; + state_ = nullptr; +} + void Agent::swap(Agent &rhs) { query_.swap(rhs.query_); key_.swap(rhs.key_); - state_.swap(rhs.state_); + marisa::swap(state_, rhs.state_); } } // namespace marisa diff --git a/lib/marisa/grimoire/trie/state.h b/lib/marisa/grimoire/trie/state.h index df605a6..07bda52 100644 --- a/lib/marisa/grimoire/trie/state.h +++ b/lib/marisa/grimoire/trie/state.h @@ -24,6 +24,14 @@ class State { : key_buf_(), history_(), node_id_(0), query_pos_(0), history_pos_(0), status_code_(MARISA_READY_TO_ALL) {} + State(const State &) = default; + State &operator=(const State &) = default; + +#if __cplusplus >= 201103L + State(State &&) noexcept = default; + State &operator=(State &&) noexcept = default; +#endif + void set_node_id(std::size_t node_id) { MARISA_DEBUG_IF(node_id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); node_id_ = (UInt32)node_id; @@ -104,10 +112,6 @@ class State { UInt32 query_pos_; UInt32 history_pos_; StatusCode status_code_; - - // Disallows copy and assignment. - State(const State &); - State &operator=(const State &); }; } // namespace trie diff --git a/lib/marisa/grimoire/vector/vector.h b/lib/marisa/grimoire/vector/vector.h index 2bfccdb..209bc70 100644 --- a/lib/marisa/grimoire/vector/vector.h +++ b/lib/marisa/grimoire/vector/vector.h @@ -23,6 +23,38 @@ class Vector { } } + Vector(const Vector &other) + : buf_(), objs_(NULL), const_objs_(NULL), + size_(0), capacity_(0), fixed_(other.fixed_) { + if (other.buf_.get() == NULL) { + objs_ = other.objs_; + const_objs_ = other.const_objs_; + size_ = other.size_; + capacity_ = other.capacity_; + } else { + copy(other.const_objs_, other.size_, other.capacity_); + } + } + + Vector &operator=(const Vector &other) { + clear(); + fixed_ = other.fixed_; + if (other.buf_.get() == NULL) { + objs_ = other.objs_; + const_objs_ = other.const_objs_; + size_ = other.size_; + capacity_ = other.capacity_; + } else { + copy(other.const_objs_, other.size_, other.capacity_); + } + return *this; + } + +#if __cplusplus >= 201103L + Vector(Vector &&) noexcept = default; + Vector &operator=(Vector &&) noexcept = default; +#endif + void map(Mapper &mapper) { Vector temp; temp.map_(mapper); @@ -225,14 +257,17 @@ class Vector { // realloc() assumes that T's placement new does not throw an exception. void realloc(std::size_t new_capacity) { MARISA_DEBUG_IF(new_capacity > max_size(), MARISA_SIZE_ERROR); + copy(objs_, size_, new_capacity); + } - scoped_array new_buf( - new (std::nothrow) char[sizeof(T) * new_capacity]); + // copy() assumes that T's placement new does not throw an exception. + void copy(const T *src, std::size_t src_size, std::size_t capacity) { + scoped_array new_buf(new (std::nothrow) char[sizeof(T) * capacity]); MARISA_DEBUG_IF(new_buf.get() == NULL, MARISA_MEMORY_ERROR); T *new_objs = reinterpret_cast(new_buf.get()); - for (std::size_t i = 0; i < size_; ++i) { - new (&new_objs[i]) T(objs_[i]); + for (std::size_t i = 0; i < src_size; ++i) { + new (&new_objs[i]) T(src[i]); } for (std::size_t i = 0; i < size_; ++i) { objs_[i].~T(); @@ -241,12 +276,9 @@ class Vector { buf_.swap(new_buf); objs_ = new_objs; const_objs_ = new_objs; - capacity_ = new_capacity; + size_ = src_size; + capacity_ = capacity; } - - // Disallows copy and assignment. - Vector(const Vector &); - Vector &operator=(const Vector &); }; } // namespace vector diff --git a/tests/marisa-test.cc b/tests/marisa-test.cc index 36e4258..4ab4e7b 100644 --- a/tests/marisa-test.cc +++ b/tests/marisa-test.cc @@ -1,7 +1,10 @@ +#include #include #include #include #include +#include +#include #include @@ -258,6 +261,41 @@ void TestPredictiveSearch(const marisa::Trie &trie, } } +void TestPredictiveSearchAgentCopy(const marisa::Trie &trie, + const marisa::Keyset &keyset) { + marisa::Agent agent; + for (std::size_t i = 0; i < keyset.size(); ++i) { + agent.set_query(keyset[i].ptr(), keyset[i].length()); + ASSERT(trie.predictive_search(agent)); + ASSERT(agent.key().id() == keyset[i].id()); + + std::vector agent_copies; + std::vector ids; + while (trie.predictive_search(agent)) { + ASSERT(agent.key().id() > keyset[i].id()); + ids.push_back(agent.key().id()); + + // Tests copy constructor. + agent_copies.push_back(agent); + } + + for (std::size_t i = 0; i < agent_copies.size(); ++i) { + marisa::Agent agent_copy; + + // Tests copy assignment. + agent_copy = agent_copies[i]; + + ASSERT(agent_copy.key().id() == ids[i]); + if (i + 1 < agent_copies.size()) { + ASSERT(trie.predictive_search(agent_copy)); + ASSERT(agent_copy.key().id() == ids[i + 1]); + } else { + ASSERT(!trie.predictive_search(agent_copy)); + } + } + } +} + void TestTrie(int num_tries, marisa::TailMode tail_mode, marisa::NodeOrder node_order, marisa::Keyset &keyset) { for (std::size_t i = 0; i < keyset.size(); ++i) { @@ -276,6 +314,7 @@ void TestTrie(int num_tries, marisa::TailMode tail_mode, TestLookup(trie, keyset); TestCommonPrefixSearch(trie, keyset); TestPredictiveSearch(trie, keyset); + TestPredictiveSearchAgentCopy(trie, keyset); trie.save("marisa-test.dat");