Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support range search, fix #245 #248

Merged
merged 1 commit into from
Jan 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/impl/MilvusClientImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

#include "TypeUtils.h"
#include "common.pb.h"
#include "milvus.grpc.pb.h"
#include "milvus.pb.h"
#include "schema.pb.h"

Expand Down Expand Up @@ -741,7 +740,13 @@ MilvusClientImpl::Search(const SearchArguments& arguments, SearchResults& result

kv_pair = rpc_request.add_search_params();
kv_pair->set_key(milvus::KeyParams());
kv_pair->set_value(arguments.ExtraParams());
// merge extra params with range search
auto json = nlohmann::json::parse(arguments.ExtraParams());
if (arguments.RangeSearch()) {
json["range_filter"] = arguments.RangeFilter();
json["radius"] = arguments.Radius();
}
kv_pair->set_value(json.dump());

rpc_request.set_travel_timestamp(arguments.TravelTimestamp());
rpc_request.set_guarantee_timestamp(arguments.GuaranteeTimestamp());
Expand Down
62 changes: 58 additions & 4 deletions src/impl/types/SearchArguments.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "milvus/types/SearchArguments.h"

#include <nlohmann/json.hpp>
#include <utility>

namespace milvus {
namespace {
Expand All @@ -28,7 +29,7 @@ struct Validation {
bool required;

Status
Validate(const SearchArguments& data, std::unordered_map<std::string, int64_t> params) const {
Validate(const SearchArguments&, std::unordered_map<std::string, int64_t> params) const {
auto it = params.find(param);
if (it != params.end()) {
auto value = it->second;
Expand All @@ -43,7 +44,7 @@ struct Validation {
};

Status
validate(const SearchArguments& data, std::unordered_map<std::string, int64_t> params) {
validate(const SearchArguments& data, const std::unordered_map<std::string, int64_t>& params) {
auto status = Status::OK();
auto validations = {
Validation{"nprobe", 1, 65536, false},
Expand Down Expand Up @@ -128,7 +129,7 @@ SearchArguments::TargetVectors() const {

Status
SearchArguments::AddTargetVector(std::string field_name, const std::string& vector) {
return AddTargetVector(field_name, std::string{vector});
return AddTargetVector(std::move(field_name), std::string{vector});
}

Status
Expand Down Expand Up @@ -223,6 +224,20 @@ SearchArguments::TopK() const {
return topk_;
}

int64_t
SearchArguments::Nprobe() const {
if (extra_params_.find("nprobe") != extra_params_.end()) {
return extra_params_.at("nprobe");
}
return 1;
}

Status
SearchArguments::SetNprobe(int64_t nprobe) {
extra_params_["nprobe"] = nprobe;
return Status::OK();
}

Status
SearchArguments::SetRoundDecimal(int round_decimal) {
round_decimal_ = round_decimal;
Expand All @@ -236,6 +251,12 @@ SearchArguments::RoundDecimal() const {

Status
SearchArguments::SetMetricType(::milvus::MetricType metric_type) {
if (((metric_type == MetricType::IP && metric_type_ == MetricType::L2) ||
(metric_type == MetricType::L2 && metric_type_ == MetricType::IP)) &&
range_search_) {
// switch radius and range_filter
std::swap(radius_, range_filter_);
}
metric_type_ = metric_type;
return Status::OK();
}
Expand All @@ -251,7 +272,7 @@ SearchArguments::AddExtraParam(std::string key, int64_t value) {
return Status::OK();
}

const std::string
std::string
SearchArguments::ExtraParams() const {
return ::nlohmann::json(extra_params_).dump();
}
Expand All @@ -261,4 +282,37 @@ SearchArguments::Validate() const {
return validate(*this, extra_params_);
}

float
SearchArguments::Radius() const {
return radius_;
}

float
SearchArguments::RangeFilter() const {
return range_filter_;
}

Status
SearchArguments::SetRange(float from, float to) {
auto low = std::min(from, to);
auto high = std::max(from, to);
if (metric_type_ == MetricType::IP) {
radius_ = low;
range_filter_ = high;
range_search_ = true;
} else if (metric_type_ == MetricType::L2) {
radius_ = high;
range_filter_ = low;
range_search_ = true;
} else {
return {StatusCode::INVALID_AGUMENT, "Metric type is not supported"};
}
return Status::OK();
}

bool
SearchArguments::RangeSearch() const {
return range_search_;
}

} // namespace milvus
46 changes: 45 additions & 1 deletion src/include/milvus/types/SearchArguments.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,18 @@ class SearchArguments {
int64_t
TopK() const;

/**
* @brief Get nprobe
*/
int64_t
Nprobe() const;

/**
* @brief Set nprobe
*/
Status
SetNprobe(int64_t nlist);

/**
* @brief Specifies the decimal place of the returned results.
*/
Expand Down Expand Up @@ -197,7 +209,7 @@ class SearchArguments {
/**
* @brief Get extra param
*/
const std::string
std::string
ExtraParams() const;

/**
Expand All @@ -207,6 +219,35 @@ class SearchArguments {
Status
Validate() const;

/**
* @brief Get range radius
* @return
*/
float
Radius() const;

/**
* @brief Get range filter
* @return
*/
float
RangeFilter() const;

/**
* @brief Set range radius
* @param from range radius from
* @param to range radius to
*/
Status
SetRange(float from, float to);

/**
* @brief Get if do range search
* @return
*/
bool
RangeSearch() const;

private:
std::string collection_name_;
std::set<std::string> partition_names_;
Expand All @@ -225,6 +266,9 @@ class SearchArguments {
int64_t topk_{1};
int round_decimal_{-1};

float radius_;
float range_filter_;
bool range_search_{false};
::milvus::MetricType metric_type_{::milvus::MetricType::L2};
};

Expand Down
70 changes: 70 additions & 0 deletions test/st/TestSearch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,76 @@ TEST_F(MilvusServerTestSearch, SearchWithoutIndex) {
dropCollection();
}

TEST_F(MilvusServerTestSearch, RangeSearch) {
std::vector<milvus::FieldDataPtr> fields{
std::make_shared<milvus::Int16FieldData>("age", std::vector<int16_t>{12, 13, 14, 15, 16, 17, 18}),
std::make_shared<milvus::VarCharFieldData>(
"name", std::vector<std::string>{"Tom", "Jerry", "Lily", "Foo", "Bar", "Jake", "Jonathon"}),
std::make_shared<milvus::FloatVecFieldData>("face", std::vector<std::vector<float>>{
std::vector<float>{0.1f, 0.2f, 0.3f, 0.4f},
std::vector<float>{0.2f, 0.3f, 0.4f, 0.5f},
std::vector<float>{0.3f, 0.4f, 0.5f, 0.6f},
std::vector<float>{0.4f, 0.5f, 0.6f, 0.7f},
std::vector<float>{0.5f, 0.6f, 0.7f, 0.8f},
std::vector<float>{0.6f, 0.7f, 0.8f, 0.9f},
std::vector<float>{0.7f, 0.8f, 0.9f, 1.0f},
})};

createCollectionAndPartitions(true);
auto dml_results = insertRecords(fields);
loadCollection();

milvus::SearchArguments arguments{};
arguments.SetCollectionName(collection_name);
arguments.AddPartitionName(partition_name);
arguments.SetRange(0.3, 1.0);
arguments.SetTopK(10);
arguments.AddOutputField("age");
arguments.AddOutputField("name");
arguments.AddTargetVector("face", std::vector<float>{0.f, 0.f, 0.f, 0.f});
arguments.AddTargetVector("face", std::vector<float>{1.f, 1.f, 1.f, 1.f});
milvus::SearchResults search_results{};
auto status = client_->Search(arguments, search_results);
EXPECT_EQ(status.Message(), "OK");
EXPECT_TRUE(status.IsOk());

const auto& results = search_results.Results();
EXPECT_EQ(results.size(), 2);

// validate results
auto validateScores = [&results](int firstRet, int secondRet) {
// check score should between range
for (const auto& result : results) {
for (const auto& score : result.Scores()) {
EXPECT_GE(score, 0.3);
EXPECT_LE(score, 1.0);
}
}
EXPECT_EQ(results.at(0).Ids().IntIDArray().size(), firstRet);
EXPECT_EQ(results.at(1).Ids().IntIDArray().size(), secondRet);
};

// valid score in range is 3, 2
validateScores(3, 2);

// add fields, then search again, should be 6 and 4
insertRecords(fields);
loadCollection();
status = client_->Search(arguments, search_results);
EXPECT_TRUE(status.IsOk());
validateScores(6, 4);

// add fields twice, and now it should be 12, 8, as limit is 10, then should be 10, 8
insertRecords(fields);
insertRecords(fields);
loadCollection();
status = client_->Search(arguments, search_results);
EXPECT_TRUE(status.IsOk());
validateScores(10, 8);

dropCollection();
}

TEST_F(MilvusServerTestSearch, SearchWithStringFilter) {
std::vector<milvus::FieldDataPtr> fields{
std::make_shared<milvus::Int16FieldData>("age", std::vector<int16_t>{12, 13}),
Expand Down
21 changes: 21 additions & 0 deletions test/ut/TestSearchArguments.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,24 @@ TEST_F(SearchArgumentsTest, ValidateTesting) {
EXPECT_TRUE(status.IsOk());
}
}

TEST_F(SearchArgumentsTest, Nprobe) {
milvus::SearchArguments arguments;
arguments.AddExtraParam("nprobe", 10);
EXPECT_EQ(10, arguments.Nprobe());

arguments.SetNprobe(20);
EXPECT_EQ(20, arguments.Nprobe());
}

TEST_F(SearchArgumentsTest, RangeSearchParams) {
milvus::SearchArguments arguments;
arguments.SetMetricType(milvus::MetricType::IP);
arguments.SetRange(0.1, 0.2);
EXPECT_NEAR(0.1, arguments.Radius(), 0.00001);
EXPECT_NEAR(0.2, arguments.RangeFilter(), 0.00001);

arguments.SetMetricType(milvus::MetricType::L2);
EXPECT_NEAR(0.2, arguments.Radius(), 0.00001);
EXPECT_NEAR(0.1, arguments.RangeFilter(), 0.00001);
}