diff --git a/be/src/common/config.h b/be/src/common/config.h index 44532ee947994..06687794ea4d1 100644 --- a/be/src/common/config.h +++ b/be/src/common/config.h @@ -1475,7 +1475,7 @@ CONF_mBool(enable_vector_index_block_cache, "true"); CONF_mInt32(config_vector_index_build_concurrency, "8"); // default not to build the empty index -CONF_mInt32(config_vector_index_default_build_threshold, "0"); +CONF_mInt32(config_vector_index_default_build_threshold, "100"); // When upgrade thrift to 0.20.0, the MaxMessageSize member defines the maximum size of a (received) message, in bytes. // The default value is represented by a constant named DEFAULT_MAX_MESSAGE_SIZE, whose value is 100 * 1024 * 1024 bytes. diff --git a/be/src/exec/pipeline/scan/olap_chunk_source.cpp b/be/src/exec/pipeline/scan/olap_chunk_source.cpp index d180bae61525e..c65df7ed430e5 100644 --- a/be/src/exec/pipeline/scan/olap_chunk_source.cpp +++ b/be/src/exec/pipeline/scan/olap_chunk_source.cpp @@ -82,6 +82,7 @@ Status OlapChunkSource::prepare(RuntimeState* state) { if (_use_vector_index) { _use_ivfpq = vector_search_options.use_ivfpq; _vector_distance_column_name = vector_search_options.vector_distance_column_name; + _vector_slot_id = vector_search_options.vector_slot_id; _params.vector_search_option = std::make_shared(); } const TupleDescriptor* tuple_desc = state->desc_tbl().get_tuple_descriptor(thrift_olap_scan_node.tuple_id); @@ -320,12 +321,10 @@ Status OlapChunkSource::_init_scanner_columns(std::vector& scanner_col for (auto slot : *_slots) { DCHECK(slot->is_materialized()); int32_t index; - if (_use_vector_index && !_use_ivfpq) { - index = _tablet_schema->field_index(slot->col_name(), _vector_distance_column_name); - if (slot->col_name() == _vector_distance_column_name) { - _params.vector_search_option->vector_column_id = index; - _params.vector_search_option->vector_slot_id = slot->id(); - } + if (_use_vector_index && !_use_ivfpq && slot->id() == _vector_slot_id) { + index = 
_tablet_schema->num_columns(); + _params.vector_search_option->vector_column_id = index; + _params.vector_search_option->vector_slot_id = slot->id(); } else { index = _tablet_schema->field_index(slot->col_name()); } @@ -352,12 +351,7 @@ Status OlapChunkSource::_init_scanner_columns(std::vector& scanner_col Status OlapChunkSource::_init_unused_output_columns(const std::vector& unused_output_columns) { for (const auto& col_name : unused_output_columns) { - int32_t index; - if (_use_vector_index && !_use_ivfpq) { - index = _tablet_schema->field_index(col_name, _vector_distance_column_name); - } else { - index = _tablet_schema->field_index(col_name); - } + int32_t index = _tablet_schema->field_index(col_name); if (index < 0) { std::stringstream ss; ss << "invalid field name: " << col_name; @@ -562,8 +556,8 @@ Status OlapChunkSource::_init_global_dicts(TabletReaderParams* params) { if (iter != global_dict_map.end()) { auto& dict_map = iter->second.first; int32_t index; - if (_use_vector_index && !_use_ivfpq) { - index = _tablet_schema->field_index(slot->col_name(), _vector_distance_column_name); + if (_use_vector_index && !_use_ivfpq && slot->id() == _vector_slot_id) { + index = _tablet_schema->num_columns(); } else { index = _tablet_schema->field_index(slot->col_name()); } diff --git a/be/src/exec/pipeline/scan/olap_chunk_source.h b/be/src/exec/pipeline/scan/olap_chunk_source.h index 8186a1ac09594..6057bbd17e253 100644 --- a/be/src/exec/pipeline/scan/olap_chunk_source.h +++ b/be/src/exec/pipeline/scan/olap_chunk_source.h @@ -106,10 +106,9 @@ class OlapChunkSource final : public ChunkSource { std::vector _column_access_paths; bool _use_vector_index = false; - bool _use_ivfpq = false; - std::string _vector_distance_column_name; + SlotId _vector_slot_id; // The following are profile meatures int64_t _num_rows_read = 0; diff --git a/be/src/storage/index/vector/tenann/del_id_filter.h b/be/src/storage/index/vector/tenann/del_id_filter.h index b23b3e245fdd5..d8b36a96a5aba 
100644 --- a/be/src/storage/index/vector/tenann/del_id_filter.h +++ b/be/src/storage/index/vector/tenann/del_id_filter.h @@ -31,6 +31,9 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. + +#pragma once + #ifdef WITH_TENANN #include "storage/del_vector.h" #include "storage/range.h" @@ -39,10 +42,10 @@ namespace starrocks { -class DelIdFilter : public tenann::IdFilter { +class DelIdFilter final : public tenann::IdFilter { public: - DelIdFilter(const SparseRange<>& scan_range); - ~DelIdFilter() = default; + explicit DelIdFilter(const SparseRange<>& scan_range); + ~DelIdFilter() override = default; bool IsMember(tenann::idx_t id) const override; @@ -51,4 +54,4 @@ class DelIdFilter : public tenann::IdFilter { }; } // namespace starrocks -#endif \ No newline at end of file +#endif diff --git a/be/src/storage/index/vector/tenann/tenann_index_builder.cpp b/be/src/storage/index/vector/tenann/tenann_index_builder.cpp index 991a64767563f..84b07655ac4c2 100644 --- a/be/src/storage/index/vector/tenann/tenann_index_builder.cpp +++ b/be/src/storage/index/vector/tenann/tenann_index_builder.cpp @@ -25,6 +25,7 @@ #include "tenann/factory/index_factory.h" namespace starrocks { + // =============== TenAnnIndexBuilderProxy ============= Status TenAnnIndexBuilderProxy::init() { @@ -35,10 +36,19 @@ Status TenAnnIndexBuilderProxy::init() { return Status::OK(); }).status()); - if (!meta.common_params().contains("dim")) { - return Status::InvalidArgument("Dim is needed because it's a critical common param"); + const auto& params = meta.common_params(); + + if (!params.contains(index::vector::DIM)) { + return Status::InvalidArgument("dim is needed because it's a critical common param"); + } + _dim = params[index::vector::DIM]; + + if (!params.contains(index::vector::METRIC_TYPE)) { + return Status::InvalidArgument("metric_type is needed because it's a critical common param"); } - _dim = 
meta.common_params()["dim"]; + _is_input_normalized = params.contains(index::vector::IS_VECTOR_NORMED) && + params[index::vector::IS_VECTOR_NORMED] && + params[index::vector::METRIC_TYPE] == tenann::MetricType::kCosineSimilarity; auto meta_copy = meta; if (meta.index_type() == tenann::IndexType::kFaissIvfPq && config::enable_vector_index_block_cache) { @@ -51,7 +61,7 @@ Status TenAnnIndexBuilderProxy::init() { // build and write index _index_builder = tenann::IndexFactory::CreateBuilderFromMeta(meta_copy); _index_builder->index_writer()->SetIndexCache(tenann::IndexCache::GetGlobalInstance()); - if (_src_is_nullable) { + if (_is_element_nullable) { _index_builder->EnableCustomRowId(); } _index_builder->Open(_index_path); @@ -66,98 +76,64 @@ Status TenAnnIndexBuilderProxy::init() { return Status::OK(); } -Status TenAnnIndexBuilderProxy::add(const Column& data) { - try { - auto vector_view = tenann::ArraySeqView{.data = const_cast(data.raw_data()), - .dim = _dim, - .size = static_cast(data.size()), - .elem_type = tenann::kFloatType}; - if (data.is_array() && data.size() != 0) { - const auto& cur_array = down_cast(data); - auto offsets = cur_array.offsets(); - size_t last_offset = 0; - auto* offsets_data = reinterpret_cast(offsets.mutable_raw_data()); - for (size_t i = 1; i < offsets.size(); i++) { - size_t dim = offsets_data[i] - last_offset; - if (dim > 0 && _dim != dim) { - LOG(WARNING) << "index dim: " << _dim << ", written dim: " << dim; - return Status::InvalidArgument( - strings::Substitute("The dimensions of the vector written are inconsistent, index dim is " - "$0 but data dim is $1, vector data is ", - _dim, dim)); - } - last_offset = offsets_data[i]; - } - } - - _index_builder->Add({vector_view}); - } catch (tenann::Error& e) { - LOG(WARNING) << e.what(); - return Status::InternalError(e.what()); +template +static Status valid_input_vector(const ArrayColumn& input_column, const size_t index_dim) { + if (input_column.empty()) { + return Status::OK(); } - 
return Status::OK(); -} -Status TenAnnIndexBuilderProxy::add(const Column& data, const Column& null_map, const size_t offset) { - try { - auto vector_view = tenann::ArraySeqView{.data = const_cast(data.raw_data()), - .dim = _dim, - .size = static_cast(data.size()), - .elem_type = tenann::kFloatType}; - if (data.is_array() && data.size() != 0) { - const auto& cur_array = down_cast(data); - auto offsets = cur_array.offsets(); - size_t last_offset = 0; - auto* offsets_data = reinterpret_cast(offsets.mutable_raw_data()); - for (size_t i = 1; i < offsets.size(); i++) { - size_t dim = offsets_data[i] - last_offset; - if (dim > 0 && _dim != dim) { - LOG(WARNING) << "index dim: " << _dim << ", written dim: " << dim; - return Status::InvalidArgument( - strings::Substitute("The dimensions of the vector written are inconsistent, index dim is " - "$0 but data dim is $1, vector data is ", - _dim, dim)); - } - last_offset = offsets_data[i]; - } + const size_t num_rows = input_column.size(); + const auto* offsets = reinterpret_cast(input_column.offsets().raw_data()); + const auto* nums = reinterpret_cast(input_column.elements().raw_data()); + + for (size_t i = 0; i < num_rows; i++) { + const size_t input_dim = offsets[i + 1] - offsets[i]; + + if (input_dim != index_dim) { + return Status::InvalidArgument( + strings::Substitute("The dimensions of the vector written are inconsistent, index dim is " + "$0 but data dim is $1", + index_dim, input_dim)); } - std::vector row_ids(data.size()); - std::iota(row_ids.begin(), row_ids.end(), offset); - _index_builder->Add({vector_view}, row_ids.data(), null_map.raw_data()); - } catch (tenann::Error& e) { - LOG(WARNING) << e.what(); - return Status::InternalError(e.what()); + if constexpr (is_input_normalized) { + double sum = 0; + for (int j = 0; j < input_dim; j++) { + sum += nums[offsets[i] + j] * nums[offsets[i] + j]; + } + if (std::abs(sum - 1) > 1e-6) { + return Status::InvalidArgument( + "The input vector is not normalized but 
`metric_type` is cosine_similarity and " + "`is_vector_normed` is true"); + } + } } + return Status::OK(); } -Status TenAnnIndexBuilderProxy::write(const Column& data) { - try { - auto vector_view = tenann::ArraySeqView{.data = const_cast(data.raw_data()), - .dim = _dim, - .size = static_cast(data.size()), - .elem_type = tenann::kFloatType}; +Status TenAnnIndexBuilderProxy::add(const Column& array_column, const size_t offset) { + DCHECK(array_column.is_array()); + const auto& array_col = down_cast(array_column); - _index_builder->Add({vector_view}); - } catch (tenann::Error& e) { - LOG(WARNING) << e.what(); - return Status::InternalError(e.what()); + DCHECK(array_col.elements_column()->is_nullable()); + const auto& nullable_elements = down_cast(array_col.elements()); + const auto& is_element_nulls = nullable_elements.null_column_ref(); + + if (_is_input_normalized) { + RETURN_IF_ERROR(valid_input_vector(array_col, _dim)); + } else { + RETURN_IF_ERROR(valid_input_vector(array_col, _dim)); } - return Status::OK(); -} -Status TenAnnIndexBuilderProxy::write(const Column& data, const Column& null_map) { try { - auto vector_view = tenann::ArraySeqView{.data = const_cast(data.raw_data()), + auto vector_view = tenann::ArraySeqView{.data = const_cast(array_col.raw_data()), .dim = _dim, - .size = static_cast(data.size()), + .size = static_cast(array_col.size()), .elem_type = tenann::kFloatType}; - - std::vector row_ids(data.size()); - std::iota(row_ids.begin(), row_ids.end(), 0); - _index_builder->Add({vector_view}, row_ids.data(), null_map.raw_data()); - + std::vector row_ids(array_col.size()); + std::iota(row_ids.begin(), row_ids.end(), offset); + _index_builder->Add({vector_view}, row_ids.data(), is_element_nulls.raw_data()); } catch (tenann::Error& e) { LOG(WARNING) << e.what(); return Status::InternalError(e.what()); @@ -175,10 +151,11 @@ Status TenAnnIndexBuilderProxy::flush() { return Status::OK(); } -void TenAnnIndexBuilderProxy::close() { +void 
TenAnnIndexBuilderProxy::close() const { if (_index_builder && !_index_builder->is_closed()) { _index_builder->Close(); } } + } // namespace starrocks #endif diff --git a/be/src/storage/index/vector/tenann/tenann_index_builder.h b/be/src/storage/index/vector/tenann/tenann_index_builder.h index 3f53affa555e2..7d29e11f1ad3c 100644 --- a/be/src/storage/index/vector/tenann/tenann_index_builder.h +++ b/be/src/storage/index/vector/tenann/tenann_index_builder.h @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#pragma once + #ifdef WITH_TENANN #include @@ -24,35 +26,34 @@ namespace starrocks { // A proxy to real Ten ANN index builder -class TenAnnIndexBuilderProxy : public VectorIndexBuilder { +class TenAnnIndexBuilderProxy final : public VectorIndexBuilder { public: TenAnnIndexBuilderProxy(std::shared_ptr tablet_index, std::string segment_index_path, - bool src_is_nullable) + bool is_element_nullable) : VectorIndexBuilder(std::move(tablet_index), std::move(segment_index_path)), - _src_is_nullable(src_is_nullable){}; + _is_element_nullable(is_element_nullable) {} // proxy should not clean index builder resource ~TenAnnIndexBuilderProxy() override { close(); }; Status init() override; - Status add(const Column& data) override; - - Status add(const Column& data, const Column& null_map, const size_t offset) override; - - Status write(const Column& data) override; - - Status write(const Column& data, const Column& null_map) override; + Status add(const Column& array_column, const size_t offset) override; Status flush() override; - void close(); + void close() const; private: - std::shared_ptr _index_builder; - uint32_t _dim = 0; OnceFlag _init_once; - bool _src_is_nullable; + std::shared_ptr _index_builder = nullptr; + uint32_t _dim = 0; + // This will be true when `metric_type` is cosine_similarity and `is_vector_normed` is true. 
+ // When it is true, the vector (a row of the array column) is either null or the sum of the squares of all elements + // equals 1. + bool _is_input_normalized = false; + + const bool _is_element_nullable; }; } // namespace starrocks diff --git a/be/src/storage/index/vector/tenann/tenann_index_utils.cpp b/be/src/storage/index/vector/tenann/tenann_index_utils.cpp index 18568650d31a7..35ccb41eb7c62 100644 --- a/be/src/storage/index/vector/tenann/tenann_index_utils.cpp +++ b/be/src/storage/index/vector/tenann/tenann_index_utils.cpp @@ -80,17 +80,16 @@ StatusOr get_vector_meta(const std::shared_ptr& if (meta.index_type() == tenann::IndexType::kFaissIvfPq) { meta.index_params()[starrocks::index::vector::NLIST] = int(4 * sqrt(starrocks::index::vector::nb_)); - CRITICAL_CHECK_AND_GET(tablet_index, index_properties, M, param_value) - meta.index_params()[starrocks::index::vector::M] = std::atoi(param_value.c_str()); - CRITICAL_CHECK_AND_GET(tablet_index, index_properties, nbits, param_value) meta.index_params()[starrocks::index::vector::NBITS] = std::atoi(param_value.c_str()); + CRITICAL_CHECK_AND_GET(tablet_index, index_properties, m_ivfpq, param_value) + meta.index_params()[starrocks::index::vector::M] = std::atoi(param_value.c_str()); } else if (meta.index_type() == tenann::IndexType::kFaissHnsw) { CRITICAL_CHECK_AND_GET(tablet_index, index_properties, efconstruction, param_value) meta.index_params()[starrocks::index::vector::EF_CONSTRUCTION] = std::atoi(param_value.c_str()); - CRITICAL_CHECK_AND_GET(tablet_index, index_properties, M, param_value) + CRITICAL_CHECK_AND_GET(tablet_index, index_properties, m, param_value) meta.index_params()[starrocks::index::vector::M] = std::atoi(param_value.c_str()); GET_OR_DEFAULT(tablet_index, search_properties, efsearch, param_value, "40") diff --git a/be/src/storage/index/vector/vector_index_builder.h b/be/src/storage/index/vector/vector_index_builder.h index 1975551862249..60fa6458558e4 100644 --- 
a/be/src/storage/index/vector/vector_index_builder.h +++ b/be/src/storage/index/vector/vector_index_builder.h @@ -29,17 +29,7 @@ class VectorIndexBuilder { // init from builder meta virtual Status init() = 0; - // add not null data - virtual Status add(const Column& data) = 0; - - // add data contains null - virtual Status add(const Column& data, const Column& null_map, const size_t offset) = 0; - - // write not null data - virtual Status write(const Column& data) = 0; - - // write data contains nulls - virtual Status write(const Column& data, const Column& null_map) = 0; + virtual Status add(const Column& array_column, const size_t offset) = 0; // flush data into disk virtual Status flush() = 0; @@ -56,20 +46,14 @@ class VectorIndexBuilder { std::string _index_path; }; -class EmptyVectorIndexBuilder : public VectorIndexBuilder { +class EmptyVectorIndexBuilder final : public VectorIndexBuilder { public: EmptyVectorIndexBuilder(std::shared_ptr tablet_index, std::string segment_index_path) : VectorIndexBuilder(std::move(tablet_index), std::move(segment_index_path)){}; Status init() override { return Status::OK(); } - Status add(const Column& data) override { return Status::OK(); } - - Status add(const Column& data, const Column& null_map, const size_t offset) override { return Status::OK(); } - - Status write(const Column& data) override { return Status::OK(); } - - Status write(const Column& data, const Column& null_map) override { return Status::OK(); } + Status add(const Column& array_column, const size_t offset) override { return Status::OK(); } Status flush() override { RETURN_IF_ERROR(VectorIndexBuilder::flush_empty(_index_path)); diff --git a/be/src/storage/index/vector/vector_index_builder_factory.cpp b/be/src/storage/index/vector/vector_index_builder_factory.cpp index 6ef43054bc613..3dc3ed7ef83cd 100644 --- a/be/src/storage/index/vector/vector_index_builder_factory.cpp +++ b/be/src/storage/index/vector/vector_index_builder_factory.cpp @@ -21,11 +21,11 @@ 
namespace starrocks { // =============== IndexBuilderFactory ============= StatusOr> VectorIndexBuilderFactory::create_index_builder( const std::shared_ptr& tablet_index, const std::string& segment_index_path, - const IndexBuilderType index_builder_type, const bool src_is_nullable) { + const IndexBuilderType index_builder_type, const bool is_element_nullable) { switch (index_builder_type) { case TEN_ANN: #ifdef WITH_TENANN - return std::make_unique(tablet_index, segment_index_path, src_is_nullable); + return std::make_unique(tablet_index, segment_index_path, is_element_nullable); #else return std::make_unique(tablet_index, segment_index_path); #endif diff --git a/be/src/storage/index/vector/vector_index_builder_factory.h b/be/src/storage/index/vector/vector_index_builder_factory.h index f2328b7f5dd43..53a94a9b892b3 100644 --- a/be/src/storage/index/vector/vector_index_builder_factory.h +++ b/be/src/storage/index/vector/vector_index_builder_factory.h @@ -28,7 +28,7 @@ class VectorIndexBuilderFactory { public: static StatusOr> create_index_builder( const std::shared_ptr& tablet_index, const std::string& segment_index_path, - const IndexBuilderType index_builder_type, const bool src_is_nullable); + const IndexBuilderType index_builder_type, const bool is_element_nullable); static StatusOr get_index_builder_type_from_config(std::shared_ptr _tablet_index) { return IndexBuilderType::TEN_ANN; diff --git a/be/src/storage/index/vector/vector_index_writer.cpp b/be/src/storage/index/vector/vector_index_writer.cpp index 64c6a0c1c0ad1..6ca49599b1611 100644 --- a/be/src/storage/index/vector/vector_index_writer.cpp +++ b/be/src/storage/index/vector/vector_index_writer.cpp @@ -21,9 +21,9 @@ namespace starrocks { void VectorIndexWriter::create(const std::shared_ptr& tablet_index, - const std::string& vector_index_file_path, bool is_nullable, + const std::string& vector_index_file_path, bool is_element_nullable, std::unique_ptr* res) { - (*res) = std::make_unique(tablet_index, 
vector_index_file_path, is_nullable); + *res = std::make_unique(tablet_index, vector_index_file_path, is_element_nullable); } Status VectorIndexWriter::init() { @@ -44,29 +44,24 @@ Status VectorIndexWriter::append(const Column& src) { int64_t duration = 0; { SCOPED_RAW_TIMER(&duration); - if (_index_builder.get() == nullptr) { + + if (_index_builder == nullptr) { if (_row_size + src.size() >= _start_vector_index_build_threshold) { RETURN_IF_ERROR(_prepare_index_builder()); } else { - if (!_buffer_column.get()) { - if (is_nullable()) { - if (src.is_nullable()) { - _buffer_column = std::make_unique(down_cast(src)); - } else { - _buffer_column = NullableColumn::wrap_if_necessary(src.clone_shared()); - } - } else { - _buffer_column = std::make_unique(down_cast(src)); - } + if (_buffer_column == nullptr) { + _buffer_column = src.clone_shared(); } else { _buffer_column->append(src, 0, src.size()); } } } - if (_index_builder.get() != nullptr) { + + if (_index_builder != nullptr) { RETURN_IF_ERROR(_append_data(src, _next_row_id)); } } + _next_row_id += src.size(); _buffer_size += src.byte_size(); _row_size += src.size(); @@ -111,11 +106,12 @@ uint64_t VectorIndexWriter::estimate_buffer_size() const { Status VectorIndexWriter::_prepare_index_builder() { ASSIGN_OR_RETURN(auto index_builder_type, VectorIndexBuilderFactory::get_index_builder_type_from_config(_tablet_index)) - ASSIGN_OR_RETURN(_index_builder, VectorIndexBuilderFactory::create_index_builder( - _tablet_index, _vector_index_file_path, index_builder_type, _is_nullable)) + ASSIGN_OR_RETURN(_index_builder, + VectorIndexBuilderFactory::create_index_builder(_tablet_index, _vector_index_file_path, + index_builder_type, _is_element_nullable)); RETURN_IF_ERROR(_index_builder->init()); - if (_buffer_column.get()) { + if (_buffer_column != nullptr) { RETURN_IF_ERROR(_append_data(*_buffer_column, 0)); _buffer_column.reset(); } @@ -123,19 +119,8 @@ Status VectorIndexWriter::_prepare_index_builder() { } Status 
VectorIndexWriter::_append_data(const Column& src, size_t offset) { - if (is_nullable()) { - if (src.is_nullable()) { - auto nullable_column = down_cast(src); - const auto& data_column_ref = nullable_column.data_column_ref(); - const auto& null_column_ref = nullable_column.null_column_ref(); - RETURN_IF_ERROR(_index_builder->add(data_column_ref, null_column_ref, offset)); - } else { - auto empty_null_ptr = NullColumn::create(src.size(), 0); - RETURN_IF_ERROR(_index_builder->add(src, *empty_null_ptr, offset)); - } - } else { - RETURN_IF_ERROR(_index_builder->add(src)); - } + DCHECK(src.is_array()); + RETURN_IF_ERROR(_index_builder->add(src, offset)); return Status::OK(); } diff --git a/be/src/storage/index/vector/vector_index_writer.h b/be/src/storage/index/vector/vector_index_writer.h index 20b4556e7a8a1..73bd050f60cc9 100644 --- a/be/src/storage/index/vector/vector_index_writer.h +++ b/be/src/storage/index/vector/vector_index_writer.h @@ -27,20 +27,22 @@ #include "types/bitmap_value.h" namespace starrocks { + class ArrayColumn; + class VectorIndexWriter { public: static void create(const std::shared_ptr& tablet_index, const std::string& vector_index_file_path, - bool is_nullable, std::unique_ptr* res); - - VectorIndexWriter() = default; - ~VectorIndexWriter() = default; + bool is_element_nullable, std::unique_ptr* res); VectorIndexWriter(const std::shared_ptr& tablet_index, std::string vector_index_file_path, - bool is_nullable) + bool is_element_nullable) : _tablet_index(tablet_index), _vector_index_file_path(std::move(vector_index_file_path)), - _is_nullable(is_nullable){}; + _is_element_nullable(is_element_nullable) { + // Element of array column must be nullable. 
+ DCHECK(_is_element_nullable); + } Status init(); @@ -54,8 +56,6 @@ class VectorIndexWriter { uint64_t total_mem_footprint() const { return estimate_buffer_size(); } - bool is_nullable() const { return _is_nullable; } - private: std::shared_ptr _tablet_index; std::string _vector_index_file_path; @@ -68,7 +68,7 @@ class VectorIndexWriter { // size of null_bit column is the same size with buffer_column // e.g. buffer_column: [1, NULL, 3, NULL, 4], null_column: [0, 1, 0, 1, 0] - bool _is_nullable; + const bool _is_element_nullable; size_t _next_row_id = 0; size_t _row_size = 0; size_t _buffer_size = 0; diff --git a/be/src/storage/rowset/array_column_writer.cpp b/be/src/storage/rowset/array_column_writer.cpp index 6fae8c973da00..57ec1265b38be 100644 --- a/be/src/storage/rowset/array_column_writer.cpp +++ b/be/src/storage/rowset/array_column_writer.cpp @@ -133,7 +133,8 @@ ArrayColumnWriter::ArrayColumnWriter(const ColumnWriterOptions& opts, TypeInfoPt DCHECK(_opts.tablet_index.count(IndexType::VECTOR) > 0); auto tablet_index = std::make_shared(_opts.tablet_index.at(IndexType::VECTOR)); std::string index_path = _opts.standalone_index_file_paths.at(IndexType::VECTOR); - VectorIndexWriter::create(tablet_index, index_path, is_nullable(), &_vector_index_writer); + // Element column of array column MUST BE nullable. + VectorIndexWriter::create(tablet_index, index_path, true, &_vector_index_writer); } } @@ -174,6 +175,8 @@ Status ArrayColumnWriter::append(const Column& column) { // 4. write vector index if (_vector_index_writer.get()) { + // Vector index only support non-nullable array column. 
+ DCHECK(!is_nullable()); RETURN_IF_ERROR(_vector_index_writer->append(*array_column)); } diff --git a/be/src/storage/rowset/segment_iterator.cpp b/be/src/storage/rowset/segment_iterator.cpp index c4bc486b9ab15..801193c45ddc1 100644 --- a/be/src/storage/rowset/segment_iterator.cpp +++ b/be/src/storage/rowset/segment_iterator.cpp @@ -413,7 +413,7 @@ SegmentIterator::SegmentIterator(std::shared_ptr segment, Schema schema _result_order = _opts.vector_search_option->result_order; _use_ivfpq = _opts.vector_search_option->use_ivfpq; _query_params = _opts.vector_search_option->query_params; - if (_vector_range > 0 && _use_ivfpq) { + if (_vector_range >= 0 && _use_ivfpq) { _k = _opts.vector_search_option->k * _opts.vector_search_option->pq_refine_factor * _opts.vector_search_option->k_factor; } else { @@ -611,7 +611,7 @@ Status SegmentIterator::_get_row_ranges_by_vector_index() { { SCOPED_RAW_TIMER(&_opts.stats->vector_search_timer); - if (_vector_range > 0) { + if (_vector_range >= 0) { st = _ann_reader->range_search(_query_view, _k, &result_ids, &result_distances, &del_id_filter, static_cast(_vector_range), _result_order); } else { diff --git a/fe/fe-core/src/main/java/com/starrocks/alter/SchemaChangeHandler.java b/fe/fe-core/src/main/java/com/starrocks/alter/SchemaChangeHandler.java index 2ff33e541d59c..98aeecc459ff0 100644 --- a/fe/fe-core/src/main/java/com/starrocks/alter/SchemaChangeHandler.java +++ b/fe/fe-core/src/main/java/com/starrocks/alter/SchemaChangeHandler.java @@ -2080,7 +2080,7 @@ public AlterJobV2 createAlterMetaJob(AlterClause alterClause, Database db, OlapT if (oldEnablePersistentIndex == enablePersistentIndex && persistentIndexType == oldPersistentIndexType) { LOG.info(String.format("table: %s enable_persistent_index is %s persistent_index_type is %s, " - + "nothing need to do", olapTable.getName(), enablePersistentIndex, persistentIndexType)); + + "nothing need to do", olapTable.getName(), enablePersistentIndex, persistentIndexType)); return null; } 
if (properties.containsKey(PropertyAnalyzer.PROPERTIES_PERSISTENT_INDEX_TYPE) @@ -2140,10 +2140,10 @@ public ShowResultSet processLakeTableAlterMeta(AlterClause alterClause, Database public void processLakeTableDropPersistentIndex(AlterClause alterClause, Database db, OlapTable olapTable) throws StarRocksException { - if (!olapTable.enablePersistentIndex() || + if (!olapTable.enablePersistentIndex() || olapTable.getPersistentIndexType() != TPersistentIndexType.CLOUD_NATIVE) { LOG.warn(String.format("drop persistent index on table %s failed, it must be" + - " cloud_native persistent index", olapTable.getName())); + " cloud_native persistent index", olapTable.getName())); throw new DdlException("drop persistent index only support cloud native index"); } Set dropPindexTablets = ((DropPersistentIndexClause) alterClause).getTabletIds(); @@ -2165,7 +2165,7 @@ public void processLakeTableDropPersistentIndex(AlterClause alterClause, Databas LOG.warn(String.format("drop persistent index on tablet %d failed, error: %s", tabletId, e.getMessage())); throw new DdlException(String.format("drop persistent index on tablet %d failed, error: %s", - tabletId, e.getMessage())); + tabletId, e.getMessage())); } } } @@ -2175,7 +2175,7 @@ public void updateTableMeta(Database db, String tableName, Map p throws DdlException { List partitions = Lists.newArrayList(); OlapTable olapTable = (OlapTable) GlobalStateMgr.getCurrentState().getLocalMetastore() - .getTable(db.getFullName(), tableName); + .getTable(db.getFullName(), tableName); Locker locker = new Locker(); locker.lockTablesWithIntensiveDbLock(db.getId(), Lists.newArrayList(olapTable.getId()), LockType.READ); @@ -2370,7 +2370,7 @@ public void updatePartitionsInMemoryMeta(Database db, List partitionNames, Map properties) throws DdlException { OlapTable olapTable = (OlapTable) GlobalStateMgr.getCurrentState().getLocalMetastore() - .getTable(db.getFullName(), tableName); + .getTable(db.getFullName(), tableName); Locker locker = new 
Locker(); locker.lockTablesWithIntensiveDbLock(db.getId(), Lists.newArrayList(olapTable.getId()), LockType.READ); try { @@ -2414,7 +2414,7 @@ public void updateBinlogPartitionTabletMeta(Database db, // be id -> Set Map> beIdToTabletId = Maps.newHashMap(); OlapTable olapTable = (OlapTable) GlobalStateMgr.getCurrentState().getLocalMetastore() - .getTable(db.getFullName(), tableName); + .getTable(db.getFullName(), tableName); Locker locker = new Locker(); locker.lockTablesWithIntensiveDbLock(db.getId(), Lists.newArrayList(olapTable.getId()), LockType.READ); @@ -2501,7 +2501,7 @@ public void updatePartitionTabletMeta(Database db, // be id -> Map> beIdToTabletSet = Maps.newHashMap(); OlapTable olapTable = (OlapTable) GlobalStateMgr.getCurrentState().getLocalMetastore() - .getTable(db.getFullName(), tableName); + .getTable(db.getFullName(), tableName); Locker locker = new Locker(); locker.lockTablesWithIntensiveDbLock(db.getId(), Lists.newArrayList(olapTable.getId()), LockType.READ); @@ -2720,6 +2720,16 @@ private void processAddIndex(CreateIndexClause alterClause, OlapTable olapTable, throw new SemanticException("GIN does not support replicated mode"); } + if (newIndex.getIndexType() == IndexType.VECTOR) { + Optional oldVectorIndex = + newIndexes.stream().filter(index -> index.getIndexType() == IndexType.VECTOR).findFirst(); + if (oldVectorIndex.isPresent()) { + throw new SemanticException( + String.format("At most one vector index is allowed for a table, but there is already a vector index [%s]", + oldVectorIndex.get().getIndexName())); + } + } + List existedIndexes = olapTable.getIndexes(); IndexDef indexDef = alterClause.getIndexDef(); Set newColset = Sets.newTreeSet(String.CASE_INSENSITIVE_ORDER); diff --git a/fe/fe-core/src/main/java/com/starrocks/analysis/VectorIndexUtil.java b/fe/fe-core/src/main/java/com/starrocks/analysis/VectorIndexUtil.java index f8f79fafc3bd7..b0d436e2b35b6 100644 --- a/fe/fe-core/src/main/java/com/starrocks/analysis/VectorIndexUtil.java +++ 
b/fe/fe-core/src/main/java/com/starrocks/analysis/VectorIndexUtil.java @@ -47,8 +47,14 @@ public static void checkVectorIndexValid(Column column, Map prop throw new SemanticException("The vector index does not support shared data mode"); } if (!Config.enable_experimental_vector) { - throw new SemanticException("The vector index is disabled, enable it by setting FE config `enable_experimental_vector` to true"); + throw new SemanticException( + "The vector index is disabled, enable it by setting FE config `enable_experimental_vector` to true"); } + + if (column.isAllowNull()) { + throw new SemanticException("The vector index can only build on non-nullable column"); + } + // Only support create vector index on DUPLICATE/PRIMARY table or key columns of UNIQUE/AGGREGATE table. if (keysType != KeysType.DUP_KEYS && keysType != KeysType.PRIMARY_KEYS) { throw new SemanticException("The vector index can only build on DUPLICATE or PRIMARY table"); @@ -67,8 +73,10 @@ public static void checkVectorIndexValid(Column column, Map prop // check param keys which must not be null Map mustNotNullParams = IndexParams.getInstance().getMustNotNullParams(IndexType.VECTOR); - Map indexIndexParams = IndexParams.getInstance().getKeySetByIndexTypeAndParamType(IndexType.VECTOR, IndexParamType.INDEX); - Map searchIndexParams = IndexParams.getInstance().getKeySetByIndexTypeAndParamType(IndexType.VECTOR, IndexParamType.SEARCH); + Map indexIndexParams = + IndexParams.getInstance().getKeySetByIndexTypeAndParamType(IndexType.VECTOR, IndexParamType.INDEX); + Map searchIndexParams = + IndexParams.getInstance().getKeySetByIndexTypeAndParamType(IndexType.VECTOR, IndexParamType.SEARCH); Map> indexParamsGroupByType = Arrays.stream(IndexParamsKey.values()).filter(belong -> belong.getBelongVectorIndexType() != null) @@ -76,7 +84,8 @@ public static void checkVectorIndexValid(Column column, Map prop Collectors.mapping(Enum::name, Collectors.toSet()))); Map> searchParamsGroupByType = - 
Arrays.stream(VectorIndexParams.SearchParamsKey.values()).filter(belong -> belong.getBelongVectorIndexType() != null) + Arrays.stream(VectorIndexParams.SearchParamsKey.values()) + .filter(belong -> belong.getBelongVectorIndexType() != null) .collect(Collectors.groupingBy(SearchParamsKey::getBelongVectorIndexType, Collectors.mapping(Enum::name, Collectors.toSet()))); @@ -119,7 +128,6 @@ public static void checkVectorIndexValid(Column column, Map prop configSearchParams.removeAll(Optional.ofNullable(searchParamsGroupByType.get(vectorIndexType)) .orElse(Collections.emptySet())); - if (!configIndexParams.isEmpty()) { throw new SemanticException(String.format("Index params %s should not define with %s", configIndexParams, vectorIndexType)); @@ -130,10 +138,33 @@ public static void checkVectorIndexValid(Column column, Map prop vectorIndexType)); } + if (vectorIndexType == VectorIndexType.IVFPQ) { + String m = properties.get(IndexParamsKey.M_IVFPQ.name().toUpperCase()); + if (m == null) { + throw new SemanticException("`M_IVFPQ` is required for IVFPQ index"); + } + // m is a valid integer which is guaranteed by checkParams. + int mValue = Integer.parseInt(m); + + String dim = properties.get(CommonIndexParamKey.DIM.name().toUpperCase()); + int dimValue = Integer.parseInt(dim); + if (dimValue % mValue != 0) { + throw new SemanticException("`DIM` should be a multiple of `M_IVFPQ` for IVFPQ index"); + } + } + // add default properties + Set indexParams = indexParamsGroupByType.get(vectorIndexType); + paramsNeedDefault.keySet().removeIf(key -> !indexParams.contains(key)); if (!paramsNeedDefault.isEmpty()) { addDefaultProperties(properties, paramsNeedDefault); } + + // Lower all the keys and values of properties. 
+ Map lowerProperties = properties.entrySet().stream() + .collect(Collectors.toMap(entry -> entry.getKey().toLowerCase(), entry -> entry.getValue().toLowerCase())); + properties.clear(); + properties.putAll(lowerProperties); } private static void addDefaultProperties(Map properties, Map paramsNeedDefault) { diff --git a/fe/fe-core/src/main/java/com/starrocks/catalog/FunctionSet.java b/fe/fe-core/src/main/java/com/starrocks/catalog/FunctionSet.java index 60810d6ec29dc..6fbf0c0d87e8e 100644 --- a/fe/fe-core/src/main/java/com/starrocks/catalog/FunctionSet.java +++ b/fe/fe-core/src/main/java/com/starrocks/catalog/FunctionSet.java @@ -157,7 +157,6 @@ public class FunctionSet { // Vector Index functions: public static final String APPROX_COSINE_SIMILARITY = "approx_cosine_similarity"; - public static final String APPROX_COSINE_SIMILARITY_NORM = "approx_cosine_similarity_norm"; public static final String APPROX_L2_DISTANCE = "approx_l2_distance"; // Geo functions: @@ -652,7 +651,6 @@ public class FunctionSet { public static final Set VECTOR_COMPUTE_FUNCTIONS = ImmutableSet.builder() .add(APPROX_COSINE_SIMILARITY) - .add(APPROX_COSINE_SIMILARITY_NORM) .add(APPROX_L2_DISTANCE) .build(); diff --git a/fe/fe-core/src/main/java/com/starrocks/catalog/IndexParams.java b/fe/fe-core/src/main/java/com/starrocks/catalog/IndexParams.java index 2adaf37945d46..0dbe8e65dc513 100644 --- a/fe/fe-core/src/main/java/com/starrocks/catalog/IndexParams.java +++ b/fe/fe-core/src/main/java/com/starrocks/catalog/IndexParams.java @@ -30,7 +30,6 @@ import java.util.Locale; import java.util.Map; import java.util.Map.Entry; -import java.util.Optional; import java.util.stream.Collectors; public class IndexParams { @@ -62,9 +61,11 @@ private IndexParams() { register(builder, IndexType.VECTOR, IndexParamType.INDEX, VectorIndexParams.IndexParamsKey.M, false, true, "16", null); register(builder, IndexType.VECTOR, IndexParamType.INDEX, VectorIndexParams.IndexParamsKey.EFCONSTRUCTION, false, true, "40", 
null); - register(builder, IndexType.VECTOR, IndexParamType.INDEX, VectorIndexParams.IndexParamsKey.NBITS, false, false, "8", + register(builder, IndexType.VECTOR, IndexParamType.INDEX, VectorIndexParams.IndexParamsKey.NBITS, false, true, "8", null); - register(builder, IndexType.VECTOR, IndexParamType.INDEX, VectorIndexParams.IndexParamsKey.NLIST, false, false, null, + register(builder, IndexType.VECTOR, IndexParamType.INDEX, VectorIndexParams.IndexParamsKey.NLIST, false, true, "16", + null); + register(builder, IndexType.VECTOR, IndexParamType.INDEX, VectorIndexParams.IndexParamsKey.M_IVFPQ, false, false, null, null); // search @@ -168,7 +169,11 @@ public Map getMustNotNullParams(IndexType indexType) { } public void checkParams(String key, String value) throws SemanticException { - Optional.ofNullable(paramsHolder.get(key)).ifPresent(p -> p.checkValue(value)); + IndexParamItem item = paramsHolder.get(key); + if (item == null) { + throw new SemanticException("Unknown index param: `" + key + "`"); + } + item.checkValue(value); } public enum IndexParamType { diff --git a/fe/fe-core/src/main/java/com/starrocks/common/VectorIndexParams.java b/fe/fe-core/src/main/java/com/starrocks/common/VectorIndexParams.java index d874d86f2d2dd..41f73c2279cca 100644 --- a/fe/fe-core/src/main/java/com/starrocks/common/VectorIndexParams.java +++ b/fe/fe-core/src/main/java/com/starrocks/common/VectorIndexParams.java @@ -23,7 +23,6 @@ import java.util.Set; import java.util.stream.Collectors; - public class VectorIndexParams { public enum CommonIndexParamKey implements ParamsKey { @@ -44,13 +43,7 @@ public void check(String value) { DIM { @Override public void check(String value) { - if (!StringUtils.isNumeric(value)) { - throw new SemanticException("Value of `DIM` must be a number"); - } - int dim = Integer.parseInt(value); - if (dim <= 0) { - throw new SemanticException("Value of `DIM` must greater then 0"); - } + validateInteger(value, "DIM", 1); } }, // Vector space metrics method, 
the enumeration of values refer to MetricsType @@ -67,7 +60,14 @@ public void check(String value) { }, // Whether vector should be normed - IS_VECTOR_NORMED, + IS_VECTOR_NORMED { + @Override + public void check(String value) { + if (!StringUtils.equalsIgnoreCase(value, "true") && !StringUtils.equalsIgnoreCase(value, "false")) { + throw new SemanticException("Value of `IS_VECTOR_NORMED` must be `true` or `false`"); + } + } + }, // Threshold of row number to build index file INDEX_BUILD_THRESHOLD @@ -82,9 +82,7 @@ public enum VectorIndexType { } public enum MetricsType { - COSINE_DISTANCE, COSINE_SIMILARITY, - INNER_PRODUCT, L2_DISTANCE, } @@ -93,30 +91,65 @@ public enum IndexParamsKey implements ParamsKey { // the parameter "M" is a crucial construction parameter that refers to the maximum number of neighbors each node can have at the base layer, // which is the bottommost layer of the graph. - M(VectorIndexType.HNSW), + M(VectorIndexType.HNSW) { + @Override + public void check(String value) { + validateInteger(value, "M", 2); + } + }, // EF_CONSTRUCTION is an important parameter that stands for the construction-time expansion factor. // This parameter controls the depth of the neighbor search for each data point during the index construction process. // Specifically, EF_CONSTRUCTION determines the size of the candidate list for nearest neighbors when inserting new nodes into the HNSW graph. - EFCONSTRUCTION(VectorIndexType.HNSW), + EFCONSTRUCTION(VectorIndexType.HNSW) { + @Override + public void check(String value) { + validateInteger(value, "EFCONSTRUCTION", 1); + } + }, // For IVFPQ // NBITS is a key parameter that refers to the number of bits used to quantize each sub-vector. Within the context of Product Quantization (PQ), // NBITS determines the size of the quantization codebook, which is the number of cluster centers in each quantized subspace. 
- NBITS(VectorIndexType.IVFPQ), - + NBITS(VectorIndexType.IVFPQ) { + @Override + public void check(String value) { + try { + double num = Double.parseDouble(value); + if (num != 8) { + throw new SemanticException(String.format("Value of `%s` must be 8", "NBITS")); + } + } catch (NumberFormatException e) { + throw new SemanticException(String.format("Value of `%s` must be a integer", "NBITS")); + } + } + }, // NLIST is a parameter in the IVF (Inverted File) indexing structure that represents the number of inverted lists, or equivalently, // the number of cluster centers (also known as visual words). // This parameter is set when constructing the index using the k-means clustering algorithm, with the purpose of grouping the vectors in the dataset // around these cluster centers. - NLIST(VectorIndexType.IVFPQ); + NLIST(VectorIndexType.IVFPQ) { + @Override + public void check(String value) { + validateInteger(value, "NLIST", 1); + } + }, + + M_IVFPQ(VectorIndexType.IVFPQ) { + @Override + public void check(String value) { + validateInteger(value, "M_IVFPQ", 2); + } + }; + + private final VectorIndexType belongVectorIndexType; - private VectorIndexType belongVectorIndexType = null; IndexParamsKey(VectorIndexType vectorIndexType) { belongVectorIndexType = vectorIndexType; } + public VectorIndexType getBelongVectorIndexType() { return belongVectorIndexType; } @@ -127,7 +160,12 @@ public enum SearchParamsKey implements ParamsKey { // The EF_SEARCH parameter represents the size of the dynamic candidate list during the search process, meaning that during the search phase, // the algorithm maintains a priority queue of size ef_search. This queue is used to store the current nearest neighbor candidates and graph // nodes for further exploration. 
- EFSEARCH(VectorIndexType.HNSW), + EFSEARCH(VectorIndexType.HNSW) { + @Override + public void check(String value) { + validateInteger(value, "EFSEARCH", 1); + } + }, // For IVFPG // NPROBE determines the number of Voronoi cells (or inverted lists) @@ -135,21 +173,36 @@ public enum SearchParamsKey implements ParamsKey { // In IVFPQ, the dataset is first divided into multiple Voronoi cells, and an inverted list is established for each cell. Then, for each query, // we first find the nearest NPROBE Voronoi cells, and then only search in the inverted lists of these cells. // The value of NPROBE affects the accuracy and efficiency of the search. - NPROBE(VectorIndexType.IVFPQ), + NPROBE(VectorIndexType.IVFPQ) { + @Override + public void check(String value) { + validateInteger(value, "NPROBE", 1); + } + }, // MAX_CODES determines the maximum number of codes to be inspected during the search phase. // In IVFPQ, the dataset is divided into multiple Voronoi cells, and each cell has an inverted list that contains the codes of all points in the cell. // For each query, we first find the nearest NPROBE Voronoi cells, and then search in the inverted lists of these cells. // However, to control the computational complexity of the search, we usually do not inspect all codes, but only inspect the first maxCodes codes in // each inverted list. The value of maxCodes affects the accuracy and efficiency of the search - MAX_CODES(VectorIndexType.IVFPQ), + MAX_CODES(VectorIndexType.IVFPQ) { + @Override + public void check(String value) { + validateInteger(value, "MAX_CODES", 0); + } + }, // SCAN_TABLE_THRESHOLD parameter is used to control the number of entries that are scanned in the lookup table during the search process. // The lookup table is a data structure that stores precomputed distances between the query and the centroids of the quantization cells. // During the search process, the algorithm scans the lookup table to find the nearest centroids to the query. 
// The SCAN_TABLE_THRESHOLD parameter determines the maximum number of entries that the algorithm will scan in the lookup table. // By adjusting this parameter, one can control the balance between search accuracy and efficiency. - SCAN_TABLE_THRESHOLD(VectorIndexType.IVFPQ), + SCAN_TABLE_THRESHOLD(VectorIndexType.IVFPQ) { + @Override + public void check(String value) { + validateInteger(value, "SCAN_TABLE_THRESHOLD", 0); + } + }, // POLYSEMOUS_HT is a parameter related to the Polysemous Coding technique. // Polysemous Coding is a technique used to balance the trade-off between recall and precision in large-scale search problems. @@ -159,19 +212,57 @@ public enum SearchParamsKey implements ParamsKey { // The POLYSEMOUS_HT parameter is a threshold used in the Polysemous Coding technique. It determines the number of "hamming thresholds" // to be used in the search process. The hamming threshold is a measure of similarity between two binary codes: the lower the hamming distance, // the more similar the codes are. - POLYSEMOUS_HT(VectorIndexType.IVFPQ), + POLYSEMOUS_HT(VectorIndexType.IVFPQ) { + @Override + public void check(String value) { + validateInteger(value, "POLYSEMOUS_HT", 0); + } + }, // The RANGE_SEARCH_CONFIDENCE parameter determines the confidence level of the range search in IVFPQ (Inverted File with Product Quantization). // We have developed our own index search algorithm based on the triangle inequality. Adjusting this parameter allows us to control the // performance and accuracy of the range search. 
- RANGE_SEARCH_CONFIDENCE(VectorIndexType.IVFPQ); + RANGE_SEARCH_CONFIDENCE(VectorIndexType.IVFPQ) { + @Override + public void check(String value) { + validateDouble(value, "RANGE_SEARCH_CONFIDENCE", 0.0, 1.0); + } + }; + private VectorIndexType belongVectorIndexType = null; + SearchParamsKey(VectorIndexType vectorIndexType) { belongVectorIndexType = vectorIndexType; } + public VectorIndexType getBelongVectorIndexType() { return belongVectorIndexType; } } + private static void validateInteger(String value, String key, Integer min) { + try { + int num = Integer.parseInt(value); + if (min != null && num < min) { + throw new SemanticException(String.format("Value of `%s` must be >= %d", key, min)); + } + } catch (NumberFormatException e) { + throw new SemanticException(String.format("Value of `%s` must be a integer", key)); + } + } + + private static void validateDouble(String value, String key, Double min, Double max) { + try { + double num = Double.parseDouble(value); + if (min != null && num < min) { + throw new SemanticException(String.format("Value of `%s` must be >= %f", key, min)); + } + if (max != null && num > max) { + throw new SemanticException(String.format("Value of `%s` must be <= %f", key, max)); + } + } catch (NumberFormatException e) { + throw new SemanticException(String.format("Value of `%s` must be a double", key)); + } + } + } diff --git a/fe/fe-core/src/main/java/com/starrocks/common/VectorSearchOptions.java b/fe/fe-core/src/main/java/com/starrocks/common/VectorSearchOptions.java index 7f141a1ed3b42..f0c97523f7b24 100644 --- a/fe/fe-core/src/main/java/com/starrocks/common/VectorSearchOptions.java +++ b/fe/fe-core/src/main/java/com/starrocks/common/VectorSearchOptions.java @@ -14,40 +14,27 @@ package com.starrocks.common; -import com.google.gson.annotations.SerializedName; -import com.google.gson.reflect.TypeToken; -import com.starrocks.persist.gson.GsonUtils; +import com.starrocks.thrift.TVectorSearchOptions; -import java.lang.reflect.Type; 
import java.util.ArrayList; import java.util.List; -import java.util.Map; public class VectorSearchOptions { + private static final int RESULT_ORDER_ASC = 0; + private static final int RESULT_ORDER_DESC = 1; - public VectorSearchOptions() {} - - @SerializedName(value = "enableUseANN") private boolean enableUseANN = false; - - @SerializedName(value = "useIVFPQ") private boolean useIVFPQ = false; - @SerializedName(value = "vectorDistanceColumnName") - private String vectorDistanceColumnName = "vector_distance"; + private String distanceColumnName = ""; + private int distanceSlotId = 0; - @SerializedName(value = "vectorLimitK") - private long vectorLimitK; + private long limitK = 0; + private int resultOrder = 0; - @SerializedName(value = "queryVector") + private double predicateRange = -1; private List queryVector = new ArrayList<>(); - @SerializedName(value = "vectorRange") - private double vectorRange = -1; - - @SerializedName(value = "resultOrder") - private int resultOrder = 0; - public boolean isEnableUseANN() { return enableUseANN; } @@ -64,57 +51,56 @@ public void setUseIVFPQ(boolean useIVFPQ) { this.useIVFPQ = useIVFPQ; } - public String getVectorDistanceColumnName() { - return vectorDistanceColumnName; - } - - public void setVectorDistanceColumnName(String vectorDistanceColumnName) { - this.vectorDistanceColumnName = vectorDistanceColumnName; + public String getDistanceColumnName() { + return distanceColumnName; } - public long getVectorLimitK() { - return vectorLimitK; + public void setDistanceColumnName(String distanceColumnName) { + this.distanceColumnName = distanceColumnName; } - public void setVectorLimitK(long vectorLimitK) { - this.vectorLimitK = vectorLimitK; + public void setDistanceSlotId(int distanceSlotId) { + this.distanceSlotId = distanceSlotId; } - public List getQueryVector() { - return queryVector; + public void setLimitK(long limitK) { + this.limitK = limitK; } public void setQueryVector(List queryVector) { this.queryVector = queryVector; 
} - public double getVectorRange() { - return vectorRange; - } - - public void setVectorRange(double vectorRange) { - this.vectorRange = vectorRange; - } - - public int getResultOrder() { - return resultOrder; - } - - public void setResultOrder(int resultOrder) { - this.resultOrder = resultOrder; + public void setPredicateRange(double predicateRange) { + this.predicateRange = predicateRange; } - public static VectorSearchOptions read(String json) { - return GsonUtils.GSON.fromJson(json, VectorSearchOptions.class); + public void setResultOrder(boolean isAsc) { + this.resultOrder = isAsc ? RESULT_ORDER_ASC : RESULT_ORDER_DESC; } - public static Map readAnnParams(String json) { - Type type = new TypeToken>() {}.getType(); - return GsonUtils.GSON.fromJson(json, type); + public TVectorSearchOptions toThrift() { + TVectorSearchOptions opts = new TVectorSearchOptions(); + opts.setEnable_use_ann(true); + opts.setVector_limit_k(limitK); + opts.setVector_distance_column_name(distanceColumnName); + opts.setVector_slot_id(distanceSlotId); + opts.setQuery_vector(queryVector); + opts.setVector_range(predicateRange); + opts.setResult_order(resultOrder); + opts.setUse_ivfpq(useIVFPQ); + return opts; } - @Override - public String toString() { - return GsonUtils.GSON.toJson(this); + public String getExplainString(String prefix) { + return prefix + "VECTORINDEX: ON" + "\n" + + prefix + prefix + + "IVFPQ: " + (useIVFPQ ? "ON" : "OFF") + ", " + + "Distance Column: <" + distanceSlotId + ":" + distanceColumnName + ">, " + + "LimitK: " + limitK + ", " + + "Order: " + (resultOrder == RESULT_ORDER_ASC ? 
"ASC" : "DESC") + ", " + + "Query Vector: " + queryVector + ", " + + "Predicate Range: " + predicateRange + + "\n"; } } \ No newline at end of file diff --git a/fe/fe-core/src/main/java/com/starrocks/planner/OlapScanNode.java b/fe/fe-core/src/main/java/com/starrocks/planner/OlapScanNode.java index 2cd4dce5ee23c..85dda6399c879 100644 --- a/fe/fe-core/src/main/java/com/starrocks/planner/OlapScanNode.java +++ b/fe/fe-core/src/main/java/com/starrocks/planner/OlapScanNode.java @@ -103,7 +103,6 @@ import com.starrocks.thrift.TScanRangeLocation; import com.starrocks.thrift.TScanRangeLocations; import com.starrocks.thrift.TTableSampleOptions; -import com.starrocks.thrift.TVectorSearchOptions; import com.starrocks.warehouse.Warehouse; import org.apache.commons.collections4.CollectionUtils; import org.apache.logging.log4j.LogManager; @@ -818,6 +817,14 @@ protected String getNodeExplainString(String prefix, TExplainLevel detailLevel) output.append(prefix).append("SORT COLUMN: ").append(sortColumn).append("\n"); } + if (Config.enable_experimental_vector) { + if (vectorSearchOptions != null && vectorSearchOptions.isEnableUseANN()) { + output.append(vectorSearchOptions.getExplainString(prefix)); + } else { + output.append(prefix).append("VECTORINDEX: OFF").append("\n"); + } + } + if (detailLevel != TExplainLevel.VERBOSE) { if (isPreAggregation) { output.append(prefix).append("PREAGGREGATION: ON").append("\n"); @@ -825,13 +832,6 @@ protected String getNodeExplainString(String prefix, TExplainLevel detailLevel) output.append(prefix).append("PREAGGREGATION: OFF. 
Reason: ").append(reasonOfPreAggregation) .append("\n"); } - if (ConnectContext.get() != null && Config.enable_experimental_vector == true) { - if (vectorSearchOptions != null && vectorSearchOptions.isEnableUseANN()) { - output.append(prefix).append("VECTORINDEX: ON").append("\n"); - } else { - output.append(prefix).append("VECTORINDEX: OFF").append("\n"); - } - } if (!conjuncts.isEmpty()) { output.append(prefix).append("PREDICATES: ").append( getExplainString(conjuncts)).append("\n"); @@ -1103,15 +1103,7 @@ protected void toThrift(TPlanNode msg) { } if (vectorSearchOptions != null && vectorSearchOptions.isEnableUseANN()) { - TVectorSearchOptions tVectorSearchOptions = new TVectorSearchOptions(); - tVectorSearchOptions.setEnable_use_ann(true); - tVectorSearchOptions.setVector_limit_k(vectorSearchOptions.getVectorLimitK()); - tVectorSearchOptions.setVector_distance_column_name(vectorSearchOptions.getVectorDistanceColumnName()); - tVectorSearchOptions.setQuery_vector(vectorSearchOptions.getQueryVector()); - tVectorSearchOptions.setVector_range(vectorSearchOptions.getVectorRange()); - tVectorSearchOptions.setResult_order(vectorSearchOptions.getResultOrder()); - tVectorSearchOptions.setUse_ivfpq(vectorSearchOptions.isUseIVFPQ()); - msg.olap_scan_node.setVector_search_options(tVectorSearchOptions); + msg.olap_scan_node.setVector_search_options(vectorSearchOptions.toThrift()); } msg.olap_scan_node.setUse_pk_index(usePkIndex); diff --git a/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java b/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java index cdaeaf6071b71..8518719f5b4b4 100644 --- a/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java +++ b/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java @@ -46,7 +46,6 @@ import com.starrocks.catalog.InternalCatalog; import com.starrocks.common.ErrorCode; import com.starrocks.common.ErrorReport; -import com.starrocks.common.VectorSearchOptions; import com.starrocks.common.io.Text; 
import com.starrocks.common.io.Writable; import com.starrocks.common.util.CompressionUtils; @@ -54,6 +53,7 @@ import com.starrocks.connector.PlanMode; import com.starrocks.datacache.DataCachePopulateMode; import com.starrocks.monitor.unit.TimeValue; +import com.starrocks.persist.gson.GsonUtils; import com.starrocks.qe.VariableMgr.VarAttr; import com.starrocks.server.GlobalStateMgr; import com.starrocks.server.RunMode; @@ -85,6 +85,7 @@ import java.io.IOException; import java.io.Serializable; import java.lang.reflect.Field; +import java.lang.reflect.Type; import java.util.List; import java.util.Map; import java.util.Objects; @@ -2208,6 +2209,18 @@ public void setEnableParallelPrepareMetadata(boolean enableParallelPrepareMetada this.enableParallelPrepareMetadata = enableParallelPrepareMetadata; } + public void setAnnParams(String annParams) { + this.annParams = annParams; + } + + public Map getAnnParams() { + if (Strings.isNullOrEmpty(annParams)) { + return Maps.newHashMap(); + } + Type type = new com.google.gson.reflect.TypeToken>() {}.getType(); + return GsonUtils.GSON.fromJson(annParams, type); + } + public String getHiveTempStagingDir() { return hiveTempStagingDir; } @@ -4655,7 +4668,7 @@ public TQueryOptions toThrift() { tResult.setConnector_io_tasks_slow_io_latency_ms(connectorIoTasksSlowIoLatency); tResult.setConnector_scan_use_query_mem_ratio(connectorScanUseQueryMemRatio); tResult.setScan_use_query_mem_ratio(scanUseQueryMemRatio); - tResult.setAnn_params(VectorSearchOptions.readAnnParams(annParams)); + tResult.setAnn_params(getAnnParams()); tResult.setPq_refine_factor(pqRefineFactor); tResult.setK_factor(kFactor); tResult.setEnable_collect_table_level_scan_stats(enableCollectTableLevelScanStats); diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/StatementPlanner.java b/fe/fe-core/src/main/java/com/starrocks/sql/StatementPlanner.java index e74fa33b1d9f4..ababf5b585c2c 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/StatementPlanner.java +++ 
b/fe/fe-core/src/main/java/com/starrocks/sql/StatementPlanner.java @@ -20,7 +20,6 @@ import com.google.common.collect.Sets; import com.starrocks.catalog.Database; import com.starrocks.catalog.ExternalOlapTable; -import com.starrocks.catalog.Index; import com.starrocks.catalog.KeysType; import com.starrocks.catalog.OlapTable; import com.starrocks.catalog.Table; @@ -31,9 +30,6 @@ import com.starrocks.common.ErrorCode; import com.starrocks.common.ErrorReport; import com.starrocks.common.LabelAlreadyUsedException; -import com.starrocks.common.VectorIndexParams.CommonIndexParamKey; -import com.starrocks.common.VectorIndexParams.VectorIndexType; -import com.starrocks.common.VectorSearchOptions; import com.starrocks.common.profile.Timer; import com.starrocks.common.profile.Tracers; import com.starrocks.http.HttpConnectContext; @@ -52,7 +48,6 @@ import com.starrocks.sql.analyzer.SemanticException; import com.starrocks.sql.ast.DeleteStmt; import com.starrocks.sql.ast.DmlStmt; -import com.starrocks.sql.ast.IndexDef; import com.starrocks.sql.ast.InsertStmt; import com.starrocks.sql.ast.QueryRelation; import com.starrocks.sql.ast.QueryStatement; @@ -87,7 +82,6 @@ import java.util.Collections; import java.util.List; -import java.util.Locale; import java.util.Map; import java.util.Set; @@ -137,14 +131,12 @@ public static ExecPlan plan(StatementBase stmt, ConnectContext session, boolean areTablesCopySafe = AnalyzerUtils.areTablesCopySafe(queryStmt); needWholePhaseLock = isLockFree(areTablesCopySafe, session) ? 
false : true; ExecPlan plan; - VectorSearchOptions vectorSearchOptions = new VectorSearchOptions(); if (needWholePhaseLock) { - plan = createQueryPlan(queryStmt, session, resultSinkType, vectorSearchOptions); + plan = createQueryPlan(queryStmt, session, resultSinkType); } else { long planStartTime = OptimisticVersion.generate(); unLock(plannerMetaLocker); - plan = createQueryPlanWithReTry(queryStmt, session, resultSinkType, plannerMetaLocker, - planStartTime, vectorSearchOptions); + plan = createQueryPlanWithReTry(queryStmt, session, resultSinkType, plannerMetaLocker, planStartTime); } setOutfileSink(queryStmt, plan); return plan; @@ -240,16 +232,14 @@ public static MVTransformerContext makeMVTransformerContext(SessionVariable sess private static ExecPlan createQueryPlan(StatementBase stmt, ConnectContext session, - TResultSinkType resultSinkType, - VectorSearchOptions vectorSearchOptions) { + TResultSinkType resultSinkType) { QueryStatement queryStmt = (QueryStatement) stmt; - checkVectorIndex(queryStmt, vectorSearchOptions); QueryRelation query = (QueryRelation) queryStmt.getQueryRelation(); List colNames = query.getColumnOutputNames(); // 1. 
Build Logical plan ColumnRefFactory columnRefFactory = new ColumnRefFactory(); LogicalPlan logicalPlan; - MVTransformerContext mvTransformerContext = makeMVTransformerContext(session.getSessionVariable()); + MVTransformerContext mvTransformerContext = makeMVTransformerContext(session.getSessionVariable()); try (Timer ignored = Tracers.watchScope("Transformer")) { // get a logicalPlan without inlining views @@ -270,8 +260,7 @@ private static ExecPlan createQueryPlan(StatementBase stmt, stmt, new PhysicalPropertySet(), new ColumnRefSet(logicalPlan.getOutputColumn()), - columnRefFactory, - vectorSearchOptions); + columnRefFactory); } try (Timer ignored = Tracers.watchScope("ExecPlanBuild")) { @@ -295,13 +284,10 @@ public static ExecPlan createQueryPlanWithReTry(QueryStatement queryStmt, ConnectContext session, TResultSinkType resultSinkType, PlannerMetaLocker plannerMetaLocker, - long planStartTime, - VectorSearchOptions vectorSearchOptions) { + long planStartTime) { QueryRelation query = queryStmt.getQueryRelation(); List colNames = query.getColumnOutputNames(); - checkVectorIndex(queryStmt, vectorSearchOptions); - // 1. 
Build Logical plan ColumnRefFactory columnRefFactory = new ColumnRefFactory(); boolean isSchemaValid = true; @@ -344,8 +330,7 @@ public static ExecPlan createQueryPlanWithReTry(QueryStatement queryStmt, queryStmt, new PhysicalPropertySet(), new ColumnRefSet(logicalPlan.getOutputColumn()), - columnRefFactory, - vectorSearchOptions); + columnRefFactory); } try (Timer ignored = Tracers.watchScope("ExecPlanBuild")) { @@ -378,36 +363,6 @@ public static ExecPlan createQueryPlanWithReTry(QueryStatement queryStmt, "schema of %s had been updated frequently during the plan generation", updatedTables); } - private static boolean checkAndSetVectorIndex(OlapTable olapTable, VectorSearchOptions vectorSearchOptions) { - for (Index index : olapTable.getIndexes()) { - if (index.getIndexType() == IndexDef.IndexType.VECTOR) { - Map indexProperties = index.getProperties(); - String indexType = indexProperties.get(CommonIndexParamKey.INDEX_TYPE.name().toLowerCase(Locale.ROOT)); - - if (VectorIndexType.IVFPQ.name().equalsIgnoreCase(indexType)) { - vectorSearchOptions.setUseIVFPQ(true); - } - - vectorSearchOptions.setEnableUseANN(true); - return true; - } - } - return false; - } - - private static void checkVectorIndex(QueryStatement queryStmt, VectorSearchOptions vectorSearchOptions) { - Set olapTables = Sets.newHashSet(); - AnalyzerUtils.copyOlapTable(queryStmt, olapTables); - boolean hasVectorIndex = false; - for (OlapTable olapTable : olapTables) { - if (checkAndSetVectorIndex(olapTable, vectorSearchOptions)) { - hasVectorIndex = true; - break; - } - } - vectorSearchOptions.setEnableUseANN(hasVectorIndex); - } - public static Set collectOriginalOlapTables(ConnectContext session, StatementBase queryStmt) { Set olapTables = Sets.newHashSet(); PlannerMetaLocker locker = new PlannerMetaLocker(session, queryStmt); diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/CreateTableAnalyzer.java b/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/CreateTableAnalyzer.java index 
c048f282491a9..0831477e5795c 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/CreateTableAnalyzer.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/CreateTableAnalyzer.java @@ -97,7 +97,7 @@ public static void analyze(CreateTableStmt statement, ConnectContext context) { analyzeTemporaryTable(statement, context, catalogName, db, tableName); } else { if (GlobalStateMgr.getCurrentState().getMetadataMgr() - .tableExists(catalogName, tableNameObject.getDb(), tableName) && !statement.isSetIfNotExists()) { + .tableExists(catalogName, tableNameObject.getDb(), tableName) && !statement.isSetIfNotExists()) { ErrorReport.reportSemanticException(ErrorCode.ERR_TABLE_EXISTS_ERROR, tableName); } } @@ -734,6 +734,16 @@ public static void analyzeIndexDefs(CreateTableStmt statement) { List indexes = new ArrayList<>(); if (CollectionUtils.isNotEmpty(indexDefs)) { + List vectorIndexNames = indexDefs.stream() + .filter(indexDef -> indexDef.getIndexType() == IndexDef.IndexType.VECTOR) + .map(IndexDef::getIndexName) + .toList(); + if (vectorIndexNames.size() > 1) { + throw new SemanticException( + String.format("At most one vector index is allowed for a table, but %d were found: %s", + vectorIndexNames.size(), vectorIndexNames)); + } + Set distinct = new TreeSet<>(String.CASE_INSENSITIVE_ORDER); Set> distinctCol = new HashSet<>(); diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/SetStmtAnalyzer.java b/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/SetStmtAnalyzer.java index 6e26967907bc1..678c6132d2e44 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/SetStmtAnalyzer.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/SetStmtAnalyzer.java @@ -25,6 +25,7 @@ import com.starrocks.analysis.StringLiteral; import com.starrocks.analysis.Subquery; import com.starrocks.catalog.ArrayType; +import com.starrocks.catalog.IndexParams; import com.starrocks.catalog.PrimitiveType; import com.starrocks.catalog.Type; import 
com.starrocks.common.ErrorCode; @@ -37,6 +38,7 @@ import com.starrocks.datacache.DataCachePopulateMode; import com.starrocks.monitor.unit.TimeValue; import com.starrocks.mysql.MysqlPassword; +import com.starrocks.persist.gson.GsonUtils; import com.starrocks.qe.ConnectContext; import com.starrocks.qe.GlobalVariable; import com.starrocks.qe.SessionVariable; @@ -65,6 +67,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Map; public class SetStmtAnalyzer { public static void analyze(SetStmt setStmt, ConnectContext session) { @@ -333,6 +336,24 @@ private static void analyzeSystemVariable(SystemVariable var) { } } + if (variable.equalsIgnoreCase(SessionVariable.ANN_PARAMS)) { + String annParams = resolvedExpression.getStringValue(); + if (!Strings.isNullOrEmpty(annParams)) { + Map annParamMap = null; + try { + java.lang.reflect.Type type = new com.google.gson.reflect.TypeToken>() {}.getType(); + annParamMap = GsonUtils.GSON.fromJson(annParams, type); + } catch (Exception e) { + throw new SemanticException(String.format("Unsupported ann_params: %s, " + + "It should be a Dict JSON string, each key and value of which is string", annParams)); + } + + for (Map.Entry entry : annParamMap.entrySet()) { + IndexParams.getInstance().checkParams(entry.getKey().toUpperCase(), entry.getValue()); + } + } + } + var.setResolvedExpression(resolvedExpression); } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/Optimizer.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/Optimizer.java index 6a5fc3c550319..b89a99c43d59c 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/Optimizer.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/Optimizer.java @@ -20,7 +20,6 @@ import com.starrocks.analysis.JoinOperator; import com.starrocks.catalog.MaterializedView; import com.starrocks.catalog.OlapTable; -import com.starrocks.common.VectorSearchOptions; import com.starrocks.common.profile.Timer; import 
com.starrocks.common.profile.Tracers; import com.starrocks.qe.ConnectContext; @@ -183,7 +182,7 @@ public OptExpression optimize(ConnectContext connectContext, ColumnRefSet requiredColumns, ColumnRefFactory columnRefFactory) { return optimize(connectContext, logicOperatorTree, null, null, requiredProperty, - requiredColumns, columnRefFactory, new VectorSearchOptions()); + requiredColumns, columnRefFactory); } public OptExpression optimize(ConnectContext connectContext, @@ -192,14 +191,11 @@ public OptExpression optimize(ConnectContext connectContext, StatementBase stmt, PhysicalPropertySet requiredProperty, ColumnRefSet requiredColumns, - ColumnRefFactory columnRefFactory, - VectorSearchOptions vectorSearchOptions) { + ColumnRefFactory columnRefFactory) { try { // prepare for optimizer prepare(connectContext, columnRefFactory, logicOperatorTree); - context.setVectorSearchOptions(vectorSearchOptions); - // prepare for mv rewrite prepareMvRewrite(connectContext, logicOperatorTree, columnRefFactory, requiredColumns); try (Timer ignored = Tracers.watchScope("MVTextRewrite")) { diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/OptimizerContext.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/OptimizerContext.java index b96b241d4ef8a..503766bf42172 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/OptimizerContext.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/OptimizerContext.java @@ -19,7 +19,6 @@ import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.starrocks.catalog.OlapTable; -import com.starrocks.common.VectorSearchOptions; import com.starrocks.qe.ConnectContext; import com.starrocks.qe.SessionVariable; import com.starrocks.server.GlobalStateMgr; @@ -57,7 +56,7 @@ public class OptimizerContext { private TaskContext currentTaskContext; private final OptimizerConfig optimizerConfig; - private Set queryTables; + private Set queryTables; private long updateTableId = -1; @@ 
-85,8 +84,6 @@ public class OptimizerContext { // collect all LogicalOlapScanOperators in the query before any optimization private List allLogicalOlapScanOperators; - private VectorSearchOptions vectorSearchOptions = new VectorSearchOptions(); - @VisibleForTesting public OptimizerContext(Memo memo, ColumnRefFactory columnRefFactory) { this.memo = memo; @@ -316,12 +313,4 @@ public void setAllLogicalOlapScanOperators(List allScan public List getAllLogicalOlapScanOperators() { return allLogicalOlapScanOperators; } - - public void setVectorSearchOptions(VectorSearchOptions vectorSearchOptions) { - this.vectorSearchOptions = vectorSearchOptions; - } - - public VectorSearchOptions getVectorSearchOptions() { - return vectorSearchOptions; - } } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/Projection.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/Projection.java index 30bb60c7dde89..7272d642c0cfe 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/Projection.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/Projection.java @@ -27,6 +27,7 @@ import java.util.Set; public class Projection { + // output column ref -> expression private final Map columnRefMap; // Used for common operator compute result reuse, we need to compute // common sub operators firstly in BE diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/RewriteToVectorPlanRule.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/RewriteToVectorPlanRule.java index 1e4a60c409d45..6ebe37c66149f 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/RewriteToVectorPlanRule.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/RewriteToVectorPlanRule.java @@ -13,30 +13,37 @@ // limitations under the License. 
package com.starrocks.sql.optimizer.rule.transformation; -import com.google.common.collect.Lists; +import com.google.common.base.Enums; +import com.google.common.base.Preconditions; import com.starrocks.analysis.BinaryType; +import com.starrocks.catalog.ArrayType; import com.starrocks.catalog.Column; -import com.starrocks.catalog.FunctionSet; +import com.starrocks.catalog.ColumnId; +import com.starrocks.catalog.Index; +import com.starrocks.catalog.OlapTable; import com.starrocks.catalog.Type; import com.starrocks.common.Config; +import com.starrocks.common.VectorIndexParams; import com.starrocks.common.VectorSearchOptions; +import com.starrocks.sql.analyzer.SemanticException; +import com.starrocks.sql.ast.IndexDef; import com.starrocks.sql.optimizer.OptExpression; import com.starrocks.sql.optimizer.OptimizerContext; -import com.starrocks.sql.optimizer.base.ColumnRefFactory; -import com.starrocks.sql.optimizer.base.Ordering; -import com.starrocks.sql.optimizer.operator.Operator; import com.starrocks.sql.optimizer.operator.OperatorType; import com.starrocks.sql.optimizer.operator.Projection; import com.starrocks.sql.optimizer.operator.logical.LogicalOlapScanOperator; import com.starrocks.sql.optimizer.operator.logical.LogicalTopNOperator; import com.starrocks.sql.optimizer.operator.pattern.Pattern; +import com.starrocks.sql.optimizer.operator.scalar.ArrayOperator; import com.starrocks.sql.optimizer.operator.scalar.BinaryPredicateOperator; import com.starrocks.sql.optimizer.operator.scalar.CallOperator; +import com.starrocks.sql.optimizer.operator.scalar.CastOperator; import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator; import com.starrocks.sql.optimizer.operator.scalar.CompoundPredicateOperator; import com.starrocks.sql.optimizer.operator.scalar.ConstantOperator; import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator; import com.starrocks.sql.optimizer.rule.RuleType; +import org.apache.commons.lang3.StringUtils; import java.util.ArrayList; 
import java.util.HashMap; @@ -46,9 +53,7 @@ import java.util.stream.Collectors; import static com.starrocks.analysis.BinaryType.GE; -import static com.starrocks.analysis.BinaryType.GT; import static com.starrocks.analysis.BinaryType.LE; -import static com.starrocks.analysis.BinaryType.LT; import static com.starrocks.catalog.FunctionSet.APPROX_COSINE_SIMILARITY; import static com.starrocks.catalog.FunctionSet.APPROX_L2_DISTANCE; @@ -62,224 +67,328 @@ public RewriteToVectorPlanRule() { @Override public boolean check(OptExpression input, OptimizerContext context) { - LogicalTopNOperator topNOperator = (LogicalTopNOperator) input.getOp(); - LogicalOlapScanOperator scanOperator = (LogicalOlapScanOperator) input.getInputs().get(0).getOp(); - - scanOperator.setVectorSearchOptions(context.getVectorSearchOptions()); - VectorSearchOptions vectorSearchOptions = scanOperator.getVectorSearchOptions(); - - if (!vectorSearchOptions.isEnableUseANN() || Config.enable_experimental_vector != true) { + if (!Config.enable_experimental_vector) { return false; } - Map columnRefMap = scanOperator.getProjection().getColumnRefMap(); + LogicalTopNOperator topNOp = (LogicalTopNOperator) input.getOp(); + LogicalOlapScanOperator scanOp = (LogicalOlapScanOperator) input.getInputs().get(0).getOp(); - boolean isEnableUseANN = false; - for (Map.Entry entry : columnRefMap.entrySet()) { - if (FunctionSet.VECTOR_COMPUTE_FUNCTIONS.contains(entry.getKey().getName()) - && entry.getValue() instanceof CallOperator) { - CallOperator callOperator = (CallOperator) entry.getValue(); - vectorSearchOptions.setQueryVector(collectVectorQuery(callOperator)); - isEnableUseANN = true; - break; - } + if (scanOp.getProjection() == null) { + return false; } - if (!isEnableUseANN) { - vectorSearchOptions.setEnableUseANN(false); + if (topNOp.getLimit() <= 0 || topNOp.getOrderByElements().size() != 1) { return false; } - if (!topNOperator.getOrderByElements().isEmpty() && - FunctionSet.VECTOR_COMPUTE_FUNCTIONS.contains( - 
topNOperator.getOrderByElements().get(0).getColumnRef().getName())) { - return topNOperator.getLimit() != Operator.DEFAULT_LIMIT && - columnRefMap.entrySet().stream() - .filter(entry -> FunctionSet.VECTOR_COMPUTE_FUNCTIONS.contains(entry.getKey().getName())) - .anyMatch(entry -> entry.getValue() instanceof CallOperator); - } - return false; + return true; } @Override public List transform(OptExpression input, OptimizerContext context) { - LogicalTopNOperator topNOperator = (LogicalTopNOperator) input.getOp(); - LogicalOlapScanOperator scanOperator = (LogicalOlapScanOperator) input.getInputs().get(0).getOp(); - - VectorSearchOptions options = scanOperator.getVectorSearchOptions(); - // set limit_K for ann searcher - options.setVectorLimitK(topNOperator.getLimit()); - ScalarOperator predicate = scanOperator.getPredicate(); - Optional newPredicate = Optional.empty(); - boolean isAscending = topNOperator.getOrderByElements().get(0).isAscending(); + LogicalTopNOperator topNOp = (LogicalTopNOperator) input.getOp(); + LogicalOlapScanOperator scanOp = (LogicalOlapScanOperator) input.getInputs().get(0).getOp(); + VectorSearchOptions opts = scanOp.getVectorSearchOptions(); + + Optional optionalInfo = extractOrderByVectorFuncInfo(topNOp, scanOp); + if (optionalInfo.isEmpty()) { + return List.of(); + } + VectorFuncInfo info = optionalInfo.get(); + + int dim = + Integer.parseInt(info.index.getProperties().get(VectorIndexParams.CommonIndexParamKey.DIM.name().toLowerCase())); + if (info.vectorQuery.size() != dim) { + throw new SemanticException( + String.format("The vector query size (%s) is not equal to the vector index dimension (%d)", + info.vectorQuery, dim)); + } + + ScalarOperator predicate = scanOp.getPredicate(); if (predicate != null) { - newPredicate = findAndSetVectorRange(predicate, isAscending, options); - if (!options.isEnableUseANN()) { - return Lists.newArrayList(input); + Optional value = extractVectorRange(predicate, info); + // If some predicates cannot be 
parsed to vector range, vector index cannot be used. + if (value.isEmpty()) { + return List.of(); } + // All the predicates are parsed to vector range, so remove predicates from scan operator. + predicate = null; + opts.setPredicateRange(value.get()); } - options.setResultOrder(isAscending ? 0 : 1); - String functionName = topNOperator.getOrderByElements().get(0).getColumnRef().getName(); - - if (functionName.equalsIgnoreCase(APPROX_L2_DISTANCE) && !isAscending || - functionName.equalsIgnoreCase(APPROX_COSINE_SIMILARITY) && isAscending || - !options.isEnableUseANN()) { - options.setEnableUseANN(false); - return Lists.newArrayList(input); - } - if (options.isUseIVFPQ()) { + + opts.setEnableUseANN(true); + String indexType = info.index.getProperties().get(VectorIndexParams.CommonIndexParamKey.INDEX_TYPE.name().toLowerCase()); + opts.setUseIVFPQ(VectorIndexParams.VectorIndexType.IVFPQ.name().equalsIgnoreCase(indexType)); + opts.setLimitK(topNOp.getLimit()); + opts.setResultOrder(info.isAscending); + opts.setDistanceColumnName("__vector_" + info.outColumnRef.getName()); + opts.setQueryVector(info.vectorQuery); + + if (opts.isUseIVFPQ()) { // Skip rewrite because IVFPQ is inaccurate and requires a brute force search after the ANN index search - input.getInputs().get(0).getOp() - .setPredicate(newPredicate.isPresent() ? newPredicate.get() : null); - return Lists.newArrayList(input); + LogicalOlapScanOperator newScanOp = LogicalOlapScanOperator.builder() + .withOperator(scanOp) + .setPredicate(predicate) + .build(); + return List.of(OptExpression.create(topNOp, OptExpression.create(newScanOp))); } - Optional result = buildVectorSortScanOperator(topNOperator, - scanOperator, context, newPredicate, options); - return result.isPresent() ? 
Lists.newArrayList(result.get()) : Lists.newArrayList(input); + return List.of(rewriteOptByDistanceColumn(topNOp, scanOp, context, predicate, info, opts)); } - public Optional buildVectorSortScanOperator(LogicalTopNOperator topNOperator, - LogicalOlapScanOperator scanOperator, OptimizerContext context, - Optional newPredicate, VectorSearchOptions vectorSearchOptions) { - // bottom-up - String distanceColumnName = scanOperator.getVectorSearchOptions().getVectorDistanceColumnName(); + private OptExpression rewriteOptByDistanceColumn(LogicalTopNOperator topNOp, + LogicalOlapScanOperator scanOp, + OptimizerContext context, + ScalarOperator newPredicate, + VectorFuncInfo info, + VectorSearchOptions opts) { + // Add index distanceColumn to the scan operator, including table, colRefToColumnMetaMap, and columnMetaToColRefMap. + String distanceColumnName = scanOp.getVectorSearchOptions().getDistanceColumnName(); Column distanceColumn = new Column(distanceColumnName, Type.FLOAT); - scanOperator.getTable().addColumn(distanceColumn); - - ColumnRefFactory columnRefFactory = context.getColumnRefFactory(); - ColumnRefOperator distanceColumnRefOperator = columnRefFactory.create(distanceColumnName, Type.FLOAT, false); - - Map colRefToColumnMetaMap = new HashMap<>(scanOperator.getColRefToColumnMetaMap()); - colRefToColumnMetaMap.put(distanceColumnRefOperator, distanceColumn); - - Map columnMetaToColRefMap = new HashMap<>(scanOperator.getColumnMetaToColRefMap()); - columnMetaToColRefMap.put(distanceColumn, distanceColumnRefOperator); - - // new Scan operator - LogicalOlapScanOperator newScanOperator = new LogicalOlapScanOperator(scanOperator.getTable(), - colRefToColumnMetaMap, columnMetaToColRefMap, scanOperator.getDistributionSpec(), - scanOperator.getLimit(), newPredicate.isPresent() ? 
newPredicate.get() : null, - scanOperator.getSelectedIndexId(), scanOperator.getSelectedPartitionId(), - scanOperator.getPartitionNames(), scanOperator.hasTableHints(), - scanOperator.getSelectedTabletId(), scanOperator.getHintsTabletIds(), - scanOperator.getHintsReplicaIds(), scanOperator.isUsePkIndex()); - - newScanOperator.setVectorSearchOptions(vectorSearchOptions); - Map scanProjectMap = new HashMap<>(); - Map topNProjectMap = new HashMap<>(); - // find original column and project it onto the topN - Optional originalColRef = scanOperator.getProjection().getColumnRefMap() - .entrySet().stream().filter(entry -> FunctionSet.VECTOR_COMPUTE_FUNCTIONS - .contains(entry.getKey().getName())).map(entry -> entry.getKey()) - .findFirst(); - if (originalColRef.isEmpty()) { - return Optional.empty(); + scanOp.getTable().addColumn(distanceColumn); + + ColumnRefOperator distanceColRef = context.getColumnRefFactory().create(distanceColumnName, Type.FLOAT, false); + Map newColRefToColumnMetaMap = new HashMap<>(scanOp.getColRefToColumnMetaMap()); + newColRefToColumnMetaMap.put(distanceColRef, distanceColumn); + + Map newColumnMetaToColRefMap = new HashMap<>(scanOp.getColumnMetaToColRefMap()); + newColumnMetaToColRefMap.put(distanceColumn, distanceColRef); + + opts.setDistanceSlotId(distanceColRef.getId()); + + // Replace the original function call by the distance column ref. 
+ Map newScanProjectMap = scanOp.getProjection().getColumnRefMap().entrySet().stream() + .collect(Collectors.toMap( + Map.Entry::getKey, + entry -> rewriteScalarOperatorByDistanceColumn(entry.getValue(), info, distanceColRef) + )); + + LogicalOlapScanOperator newScanOp = LogicalOlapScanOperator.builder() + .withOperator(scanOp) + .setProjection(new Projection(newScanProjectMap)) + .setPredicate(newPredicate) + .setColRefToColumnMetaMap(newColRefToColumnMetaMap) + .setColumnMetaToColRefMap(newColumnMetaToColRefMap) + .build(); + + return OptExpression.create(topNOp, OptExpression.create(newScanOp)); + } + + ScalarOperator rewriteScalarOperatorByDistanceColumn(ScalarOperator scalarOperator, VectorFuncInfo info, + ColumnRefOperator distanceColRef) { + if (scalarOperator.equals(info.vectorFuncCallOperator)) { + return distanceColRef; + } + + for (int i = 0; i < scalarOperator.getChildren().size(); i++) { + ScalarOperator child = scalarOperator.getChild(i); + scalarOperator.setChild(i, rewriteScalarOperatorByDistanceColumn(child, info, distanceColRef)); + } + + return scalarOperator; + } + + /** + * Check if the operator matches the specific vector function call. + * + *

For example, assume that `vectorFuncCallOperator` is `approx_l2_distance(v1, [1,2,3])`, + * then the following operators match: + * - `approx_l2_distance(v1, [1,2,3])` + * - `cast(approx_l2_distance(v1, [1,2,3]) as float)` + * - `cast(approx_l2_distance(v1, [1,2,3]) as double)` + */ + private boolean matchesVectorFuncCall(CallOperator vectorFuncCallOperator, ScalarOperator operator) { + if (operator instanceof CastOperator) { + CastOperator castOperator = (CastOperator) operator; + return castOperator.getType().isFloatingPointType() && + matchesVectorFuncCall(vectorFuncCallOperator, castOperator.getChild(0)); + } + + if (operator instanceof CallOperator) { + return vectorFuncCallOperator.equals(operator); } - scanOperator.getProjection().getColumnRefMap().entrySet().stream() - .forEach(entry -> { - if (FunctionSet.VECTOR_COMPUTE_FUNCTIONS.contains(entry.getKey().getName()) - && entry.getValue() instanceof CallOperator) { - scanProjectMap.put(distanceColumnRefOperator, distanceColumnRefOperator); + return false; + } + + /** + * Extract the vector range from the predicates. + * + *

Only the predicates in the following format can be parsed to vector range: + * - req1: <=, >=, and one side is constant, the other side is the vector index column. + * - req2: AND, and each child predicate meets req1. + * + *

For example, suppose v1 is the vector index column and isAscending=true, then: + * - v1 <= 10: 10 + * - v1 <= 10 AND v1 <= 20: 10 + * - v1 >= 10: cannot be parsed + * - c1 <= 10: cannot be parsed + * - v1 <= 10 and c1 < 10: cannot be parsed + * + * @return the vector range value if the predicate can be parsed to vector range, otherwise empty. + */ + private Optional extractVectorRange(ScalarOperator predicate, VectorFuncInfo info) { + if (predicate instanceof BinaryPredicateOperator) { + return parseVectorRangeFromBinaryPredicate(predicate, info); + } else if (predicate instanceof CompoundPredicateOperator) { + CompoundPredicateOperator compoundPredicate = (CompoundPredicateOperator) predicate; + if (!compoundPredicate.isAnd()) { + return Optional.empty(); + } + Optional value = Optional.empty(); + for (ScalarOperator child : predicate.getChildren()) { + Optional childValue = parseVectorRangeFromBinaryPredicate(child, info); + if (childValue.isEmpty()) { + return Optional.empty(); + } + if (value.isEmpty()) { + value = childValue; + } else { + if (info.isAscending) { + value = Optional.of(Math.min(value.get(), childValue.get())); } else { - scanProjectMap.put(entry.getKey(), entry.getValue()); - topNProjectMap.put(entry.getKey(), entry.getValue()); + value = Optional.of(Math.max(value.get(), childValue.get())); } - }); - newScanOperator.setProjection(new Projection(scanProjectMap)); - - List orderByElements = topNOperator.getOrderByElements().stream().map(ordering -> - FunctionSet.VECTOR_COMPUTE_FUNCTIONS.contains(ordering.getColumnRef().getName()) ? 
- new Ordering(distanceColumnRefOperator, ordering.isAscending(), ordering.isNullsFirst()) : ordering) - .collect(Collectors.toList()); - - boolean hasProjection = topNOperator.getProjection() != null; - Map newTopNProjectMap = new HashMap<>(); - if (hasProjection) { - topNOperator.getProjection().getColumnRefMap().entrySet().stream() - .forEach(entry -> { - if (FunctionSet.VECTOR_COMPUTE_FUNCTIONS.contains(entry.getKey().getName())) { - newTopNProjectMap.put(originalColRef.get(), distanceColumnRefOperator); - } else { - newTopNProjectMap.put(entry.getKey(), entry.getValue()); - } - }); - } else { - topNProjectMap.put(originalColRef.get(), distanceColumnRefOperator); + } + } + return value; } - // new TopN operator - LogicalTopNOperator newTopNOperator = new LogicalTopNOperator(topNOperator.getLimit(), - topNOperator.getPredicate(), - hasProjection ? new Projection(newTopNProjectMap) : new Projection(topNProjectMap), - topNOperator.getPartitionByColumns(), topNOperator.getPartitionLimit(), orderByElements, - topNOperator.getOffset(), topNOperator.getSortPhase(), topNOperator.getTopNType(), topNOperator.isSplit()); + return Optional.empty(); + } - OptExpression topNExpression = OptExpression.create(newTopNOperator); - topNExpression.getInputs().clear(); - topNExpression.getInputs().add(OptExpression.create(newScanOperator)); + private Optional parseVectorRangeFromBinaryPredicate(ScalarOperator predicate, VectorFuncInfo info) { + if (predicate instanceof BinaryPredicateOperator) { + BinaryType binaryType = ((BinaryPredicateOperator) predicate).getBinaryType(); + ScalarOperator lhs = predicate.getChild(0); + ScalarOperator rhs = predicate.getChild(1); + + if (rhs instanceof ConstantOperator && matchesVectorFuncCall(info.vectorFuncCallOperator, lhs) && + (((binaryType.equals(LE)) && info.isAscending) || ((binaryType.equals(GE)) && !info.isAscending))) { + return Optional.of((double) ((ConstantOperator) rhs).getValue()); + } else if (lhs instanceof ConstantOperator && 
matchesVectorFuncCall(info.vectorFuncCallOperator, rhs) && + (((binaryType.equals(GE)) && info.isAscending) || ((binaryType.equals(LE)) && !info.isAscending))) { + return Optional.of((double) ((ConstantOperator) lhs).getValue()); + } + } - return Optional.of(topNExpression); + return Optional.empty(); } - public Optional findAndSetVectorRange(ScalarOperator operator, - boolean isAscending, VectorSearchOptions options) { - if (!options.isEnableUseANN()) { + /** + * Extract the vector function information. If the vector index can be used, the following requirements need to be met: + * 1. The first column of the ordering is the function. + * 2. The function needs to match the metric_type and isAscending of the vector index. + * - If the metric_type is L2_DISTANCE, then the function is approx_l2_distance, and the order is ASC. + * - If the metric_type is COSINE_SIMILARITY, then the function is approx_cosine_similarity, and the order is DESC. + * 3. The arguments of the function are the vector index column and a constant array. + * + * @return the vector function information if the ordering column is matched, otherwise empty. 
+ */ + private Optional extractOrderByVectorFuncInfo(LogicalTopNOperator topNOp, LogicalOlapScanOperator scanOp) { + OlapTable table = (OlapTable) scanOp.getTable(); + Index index = table.getIndexes().stream() + .filter(i -> i.getIndexType() == IndexDef.IndexType.VECTOR) + .findFirst() + .orElse(null); + if (index == null) { return Optional.empty(); } - if (operator instanceof BinaryPredicateOperator && operator.getChild(1) instanceof ConstantOperator && - (isVectorCallOperator(operator.getChild(0)))) { - BinaryType binaryType = ((BinaryPredicateOperator) operator).getBinaryType(); - if (((binaryType.equals(LE) || binaryType.equals(LT)) && !isAscending) || - ((binaryType.equals(GE) || binaryType.equals(GT)) && isAscending)) { - options.setEnableUseANN(false); - return Optional.empty(); - } - options.setVectorRange((double) (((ConstantOperator) operator.getChild(1)).getValue())); + ColumnRefOperator outColRef = topNOp.getOrderByElements().get(0).getColumnRef(); + final boolean isAscending = topNOp.getOrderByElements().get(0).isAscending(); + + String rawMetricType = index.getProperties().get(VectorIndexParams.CommonIndexParamKey.METRIC_TYPE.name().toLowerCase()); + VectorIndexParams.MetricsType metricType = + Enums.getIfPresent(VectorIndexParams.MetricsType.class, StringUtils.upperCase(rawMetricType)).orNull(); + Preconditions.checkNotNull(metricType, "Invalid metric type [" + rawMetricType + "] for vector index"); + + // 1. Check: it is a matched vector function. 
+ ScalarOperator inOperator = scanOp.getProjection().getColumnRefMap().get(outColRef); + if (!(inOperator instanceof CallOperator)) { return Optional.empty(); - } else if (operator instanceof CompoundPredicateOperator) { - List newOperators = new ArrayList<>(); - for (ScalarOperator child : operator.getChildren()) { - Optional newChild = findAndSetVectorRange(child, isAscending, options); - if (newChild.isPresent()) { - newOperators.add(newChild.get()); - } - } - if (newOperators.size() > 1) { - return Optional.of(new CompoundPredicateOperator(((CompoundPredicateOperator) operator).getCompoundType(), - newOperators)); - } else if (newOperators.size() == 1) { - return Optional.of(newOperators.get(0)); - } else { - return Optional.empty(); - } + } + CallOperator inCallOperator = (CallOperator) inOperator; + + boolean matchedFunc; + switch (metricType) { + case L2_DISTANCE: + matchedFunc = inCallOperator.getFnName().equalsIgnoreCase(APPROX_L2_DISTANCE) && isAscending; + break; + case COSINE_SIMILARITY: + matchedFunc = inCallOperator.getFnName().equalsIgnoreCase(APPROX_COSINE_SIMILARITY) && !isAscending; + break; + default: + matchedFunc = false; + } + if (!matchedFunc) { + return Optional.empty(); + } + + // 2. Check: the vector function's arguments are column ref and constant. + ScalarOperator lhs = inCallOperator.getChild(0); + ScalarOperator rhs = inCallOperator.getChild(1); + ColumnRefOperator colRefArgument; + if (isConstantArrayFloat(lhs) && rhs.isColumnRef()) { + colRefArgument = (ColumnRefOperator) rhs; + } else if (isConstantArrayFloat(rhs) && lhs.isColumnRef()) { + colRefArgument = (ColumnRefOperator) lhs; } else { - options.setEnableUseANN(false); - return Optional.of(operator.clone()); + return Optional.empty(); } - } - public boolean isVectorCallOperator(ScalarOperator scalarOperator) { - if (scalarOperator instanceof CallOperator && - FunctionSet.VECTOR_COMPUTE_FUNCTIONS.contains(((CallOperator) scalarOperator).getFnName())) { - return true; + // 3. 
Check: the column ref argument of the vector function matches the index column. + Column column = scanOp.getColRefToColumnMetaMap().get(colRefArgument); + if (column == null) { + return Optional.empty(); } - if (scalarOperator.getChildren().size() == 0) { - return false; + + ColumnId indexColumnId = index.getColumns().get(0); + if (!column.getColumnId().equals(indexColumnId)) { + return Optional.empty(); } - return isVectorCallOperator(scalarOperator.getChild(0)); - } - public List collectVectorQuery(CallOperator callOperator) { - // suppose it's a standard vector query + // 4. Parse query vector values. List vectorQuery = new ArrayList<>(); - collectVector(callOperator, vectorQuery); - return vectorQuery; + extractValuesFromConstantArray(inCallOperator, vectorQuery); + + return Optional.of( + new VectorFuncInfo(index, colRefArgument, outColRef, inCallOperator, metricType, vectorQuery, isAscending)); + } + + /** + * Whether the scalar operator is a constant array of float, which is represented as + * `ArrayOperator(type=ArrayType(float))` or + * `CastOperator(child=ArrayOperator(type=ArrayType(numeric_type)), type=ArrayType(float))`. 
+ */ + private boolean isConstantArrayFloat(ScalarOperator scalarOperator) { + if (!scalarOperator.isConstant()) { + return false; + } + + if (scalarOperator instanceof CastOperator) { + if (!scalarOperator.getType().isArrayType()) { + return false; + } + ArrayType arrayType = (ArrayType) scalarOperator.getType(); + if (!arrayType.getItemType().isFloatingPointType()) { + return false; + } + + return scalarOperator.getChildren().stream().allMatch(this::isConstantArrayFloat); + } else if (scalarOperator instanceof ArrayOperator) { + if (!scalarOperator.getType().isArrayType()) { + return false; + } + ArrayType innerArrayType = (ArrayType) scalarOperator.getType(); + return innerArrayType.getItemType().isNumericType(); + } else { + return false; + } } - public void collectVector(ScalarOperator scalarOperator, List vectorQuery) { + private void extractValuesFromConstantArray(ScalarOperator scalarOperator, List vectorQuery) { if (scalarOperator instanceof ColumnRefOperator) { return; } @@ -290,7 +399,34 @@ public void collectVector(ScalarOperator scalarOperator, List vectorQuer } for (ScalarOperator child : scalarOperator.getChildren()) { - collectVector(child, vectorQuery); + extractValuesFromConstantArray(child, vectorQuery); + } + } + + private static class VectorFuncInfo { + private final Index index; + // vector index column + private final ColumnRefOperator inColumnRef; + // The column ref of the first ordering column, which is obtained by vectorFuncCallOperator `(inColumnRef, vectorQuery)`. + // - If metricType is L2_DISTANCE, then function is `approx_l2_distance`, and the order is ASC. + // - If metricType is COSINE_SIMILARITY, then function `is cosine_similarity`, and the order is DESC. 
+ private final ColumnRefOperator outColumnRef; + private final CallOperator vectorFuncCallOperator; + private final VectorIndexParams.MetricsType metricType; + // The constant vector argument value of the function + private final List vectorQuery; + private final boolean isAscending; + + public VectorFuncInfo(Index index, ColumnRefOperator inColumnRef, ColumnRefOperator outColumnRef, + CallOperator vectorFuncCallOperator, + VectorIndexParams.MetricsType metricType, List vectorQuery, boolean isAscending) { + this.index = index; + this.inColumnRef = inColumnRef; + this.outColumnRef = outColumnRef; + this.vectorFuncCallOperator = vectorFuncCallOperator; + this.metricType = metricType; + this.vectorQuery = vectorQuery; + this.isAscending = isAscending; } } } diff --git a/fe/fe-core/src/test/java/com/starrocks/analysis/VectorIndexTest.java b/fe/fe-core/src/test/java/com/starrocks/analysis/VectorIndexTest.java index 43e4513c67f59..7f6ce33ca0684 100644 --- a/fe/fe-core/src/test/java/com/starrocks/analysis/VectorIndexTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/analysis/VectorIndexTest.java @@ -86,7 +86,7 @@ public void testCheckVectorIndex() { () -> VectorIndexUtil.checkVectorIndexValid(c3, Collections.emptyMap(), KeysType.DUP_KEYS), "You should set index_type at least to add a vector index."); - Column c4 = new Column("f4", Type.ARRAY_FLOAT, true); + Column c4 = new Column("f4", Type.ARRAY_FLOAT, false); Assertions.assertThrows( SemanticException.class, () -> VectorIndexUtil.checkVectorIndexValid(c4, new HashMap<>() {{ @@ -148,8 +148,7 @@ public void testCheckVectorIndex() { put(VectorIndexParams.IndexParamsKey.M.name(), "10"); put(VectorIndexParams.IndexParamsKey.EFCONSTRUCTION.name(), "10"); put(VectorIndexParams.SearchParamsKey.EFSEARCH.name(), "10"); - }}, KeysType.DUP_KEYS), - "Params HNSW should not define with NBITS" + }}, KeysType.DUP_KEYS) ); Map paramItemMap = new HashMap<>(){{ diff --git 
a/fe/fe-core/src/test/java/com/starrocks/planner/VectorIndexTest.java b/fe/fe-core/src/test/java/com/starrocks/planner/VectorIndexTest.java index 152504daebda7..135df9e5407a8 100644 --- a/fe/fe-core/src/test/java/com/starrocks/planner/VectorIndexTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/planner/VectorIndexTest.java @@ -36,10 +36,13 @@ import com.starrocks.common.Config; import com.starrocks.common.FeConstants; +import com.starrocks.sql.analyzer.SemanticException; import com.starrocks.sql.plan.PlanTestBase; import org.junit.BeforeClass; import org.junit.Test; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + public class VectorIndexTest extends PlanTestBase { @BeforeClass @@ -49,7 +52,9 @@ public static void beforeClass() throws Exception { FeConstants.enablePruneEmptyOutputScan = false; starRocksAssert.withTable("CREATE TABLE test.test_cosine (" + " c0 INT," - + " c1 array," + + " c1 array NOT NULL," + + " c2 array," + + " vector_distance float," + " INDEX index_vector1 (c1) USING VECTOR ('metric_type' = 'cosine_similarity', " + "'is_vector_normed' = 'false', 'M' = '512', 'index_type' = 'hnsw', 'dim'='5') " + ") " @@ -59,7 +64,8 @@ public static void beforeClass() throws Exception { starRocksAssert.withTable("CREATE TABLE test.test_l2 (" + " c0 INT," - + " c1 array," + + " c1 array NOT NULL," + + " c2 array," + " INDEX index_vector1 (c1) USING VECTOR ('metric_type' = 'l2_distance', " + "'is_vector_normed' = 'false', 'M' = '512', 'index_type' = 'hnsw', 'dim'='5') " + ") " @@ -69,53 +75,481 @@ public static void beforeClass() throws Exception { starRocksAssert.withTable("CREATE TABLE test.test_ivfpq (" + " c0 INT," - + " c1 array," + + " c1 array NOT NULL," + + " c2 array," + " INDEX index_vector1 (c1) USING VECTOR ('metric_type' = 'l2_distance', " - + "'is_vector_normed' = 'false', 'nbits' = '1', 'index_type' = 'ivfpq', 'dim'='5') " + + "'is_vector_normed' = 'false', 'nbits' = '8', 'index_type' = 'ivfpq', 'dim'='4', 'm_ivfpq'='2') " + 
+ ") " + + "DUPLICATE KEY(c0) " + + "DISTRIBUTED BY HASH(c0) BUCKETS 1 " + + "PROPERTIES ('replication_num'='1');"); + + starRocksAssert.withTable("CREATE TABLE test.test_no_vector_index (" + + " c0 INT," + + " c1 array," + + " c2 array" + ") " + "DUPLICATE KEY(c0) " + "DISTRIBUTED BY HASH(c0) BUCKETS 1 " + "PROPERTIES ('replication_num'='1');"); } + @Test + public void testMeetOrderByRequirement() throws Exception { + String sql; + String plan; + + // Basic cases. + sql = "select c1 from test_cosine " + + "order by approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1) desc limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, " 2:TOP-N\n" + + " | order by: [5, FLOAT, false] DESC\n" + + " | build runtime filters:\n" + + " | - filter_id = 0, build_expr = ( 5: approx_cosine_similarity), remote = false\n" + + " | offset: 0\n" + + " | limit: 10\n" + + " | cardinality: 1\n" + + " | \n" + + " 1:Project\n" + + " | output columns:\n" + + " | 2 <-> [2: c1, ARRAY, false]\n" + + " | 5 <-> [7: __vector_approx_cosine_similarity, FLOAT, false]\n" + + " | cardinality: 1\n" + + " | \n" + + " 0:OlapScanNode\n" + + " table: test_cosine, rollup: test_cosine\n" + + " VECTORINDEX: ON\n" + + " IVFPQ: OFF, Distance Column: <7:__vector_approx_cosine_similarity>, LimitK: 10, Order: DESC, " + + "Query Vector: [1.1, 2.2, 3.3, 4.4, 5.5], Predicate Range: -1.0"); + + sql = "select c1 from test_l2 " + + "order by approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, " VECTORINDEX: ON\n" + + " IVFPQ: OFF, Distance Column: <6:__vector_approx_l2_distance>, LimitK: 10, Order: ASC, " + + "Query Vector: [1.1, 2.2, 3.3, 4.4, 5.5], Predicate Range: -1.0"); + + // Constant vector with cast. 
+ sql = "select c1 from test_cosine " + + "order by approx_cosine_similarity([cast(1.1 as double),cast(2.1 as double)," + + "cast(3.1 as double),cast(4.1 as double),cast(5.1 as double)], c1) desc " + + "limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, " VECTORINDEX: ON\n" + + " IVFPQ: OFF, Distance Column: <7:__vector_approx_cosine_similarity>, LimitK: 10, Order: DESC, " + + "Query Vector: [1.1, 2.1, 3.1, 4.1, 5.1], Predicate Range: -1.0"); + + sql = "select c1 from test_cosine " + + "order by approx_cosine_similarity([cast(1.1 as float),cast(2.1 as float),cast(3.1 as float)" + + ",cast(4.1 as float),cast(5.1 as float)], c1) desc " + + "limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, " VECTORINDEX: ON\n" + + " IVFPQ: OFF, Distance Column: <7:__vector_approx_cosine_similarity>, LimitK: 10, Order: DESC, " + + "Query Vector: [1.1, 2.1, 3.1, 4.1, 5.1], Predicate Range: -1.0"); + + sql = "select c1 from test_cosine " + + "order by approx_cosine_similarity([cast(1.1 as int),cast(2.1 as int),cast(3.1 as int)" + + ",cast(4.1 as int),cast(5.1 as int)], c1) desc " + + "limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, " VECTORINDEX: ON\n" + + " IVFPQ: OFF, Distance Column: <7:__vector_approx_cosine_similarity>, LimitK: 10, Order: DESC, " + + "Query Vector: [1.1, 2.1, 3.1, 4.1, 5.1], Predicate Range: -1.0"); + } + + @Test + public void testNotMeetOrderByRequirement() throws Exception { + String sql; + String plan; + + // Wrong function name. + sql = "select c1 from test_l2 " + + "order by approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1) limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, "VECTORINDEX: OFF"); + + sql = "select c1 from test_cosine " + + "order by approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, "VECTORINDEX: OFF"); + + // Wrong column ref. 
+ sql = "select c1 from test_l2 " + + "order by approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c2) limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, "VECTORINDEX: OFF"); + + // Wrong constant vector + sql = "select c1 from test_l2 " + + "order by approx_l2_distance(['a', 'b', 'c'], c1) limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, "VECTORINDEX: OFF"); + + sql = "select c1 from test_l2 " + + "order by approx_l2_distance(c2, c1) limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, "VECTORINDEX: OFF"); + + // Wrong ASC/DESC + sql = "select c1 from test_l2 " + + "order by approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) DESC limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, "VECTORINDEX: OFF"); + + sql = "select c1 from test_cosine " + + "order by approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1) limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, "VECTORINDEX: OFF"); + + // No limit. + sql = "select c1 from test_cosine " + + "order by approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1) DESC"; + plan = getVerboseExplain(sql); + assertContains(plan, "VECTORINDEX: OFF"); + } + + @Test + public void testMeetPredicateRequirement() throws Exception { + String sql; + String plan; + + // Basic cases. 
+ sql = "select c1 from test_cosine " + + "where approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1) >= 100 " + + "order by approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1) desc limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, " VECTORINDEX: ON\n" + + " IVFPQ: OFF, Distance Column: <7:__vector_approx_cosine_similarity>, LimitK: 10, Order: DESC, " + + "Query Vector: [1.1, 2.2, 3.3, 4.4, 5.5], Predicate Range: 100.0"); + + sql = "select c1 from test_l2 " + + "where approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) <= 100 " + + "order by approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, " VECTORINDEX: ON\n" + + " IVFPQ: OFF, Distance Column: <6:__vector_approx_l2_distance>, LimitK: 10, Order: ASC, " + + "Query Vector: [1.1, 2.2, 3.3, 4.4, 5.5], Predicate Range: 100.0"); + + sql = "select c1 from test_cosine " + + "where approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1) >= 100 " + + "order by approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1) desc limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, " VECTORINDEX: ON\n" + + " IVFPQ: OFF, Distance Column: <7:__vector_approx_cosine_similarity>, LimitK: 10, Order: DESC, " + + "Query Vector: [1.1, 2.2, 3.3, 4.4, 5.5], Predicate Range: 100.0"); + + sql = "select c1 from test_l2 " + + "where approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) <= 100 " + + "order by approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, " VECTORINDEX: ON\n" + + " IVFPQ: OFF, Distance Column: <6:__vector_approx_l2_distance>, LimitK: 10, Order: ASC, " + + "Query Vector: [1.1, 2.2, 3.3, 4.4, 5.5], Predicate Range: 100.0"); + + // Cast + sql = "select c1 from test_l2 " + + "where approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) <= cast(100 as double) " + + "order by approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, " VECTORINDEX: ON\n" + + " 
IVFPQ: OFF, Distance Column: <6:__vector_approx_l2_distance>, LimitK: 10, Order: ASC, " + + "Query Vector: [1.1, 2.2, 3.3, 4.4, 5.5], Predicate Range: 100.0"); + + sql = "select c1 from test_l2 " + + "where approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) <= cast(100 as int) " + + "order by approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, " VECTORINDEX: ON\n" + + " IVFPQ: OFF, Distance Column: <6:__vector_approx_l2_distance>, LimitK: 10, Order: ASC, " + + "Query Vector: [1.1, 2.2, 3.3, 4.4, 5.5], Predicate Range: 100.0"); + + sql = "select c1 from test_l2 " + + "where approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) <= cast(100 as float) " + + "order by approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, " VECTORINDEX: ON\n" + + " IVFPQ: OFF, Distance Column: <6:__vector_approx_l2_distance>, LimitK: 10, Order: ASC, " + + "Query Vector: [1.1, 2.2, 3.3, 4.4, 5.5], Predicate Range: 100.0"); + + // AND + sql = "select c1 from test_cosine " + + "where approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1) >= 1000 " + + "and approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1) >= 100 " + + "order by approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1) desc limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, " VECTORINDEX: ON\n" + + " IVFPQ: OFF, Distance Column: <7:__vector_approx_cosine_similarity>, LimitK: 10, Order: DESC, " + + "Query Vector: [1.1, 2.2, 3.3, 4.4, 5.5], Predicate Range: 1000.0"); + + sql = "select c1 from test_l2 " + + "where approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) <= 100 and approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) <= 1000 " + + "order by approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, " VECTORINDEX: ON\n" + + " IVFPQ: OFF, Distance Column: <6:__vector_approx_l2_distance>, LimitK: 10, Order: ASC, " + + "Query Vector: [1.1, 2.2, 3.3, 4.4, 5.5], 
Predicate Range: 100.0"); + } + + @Test + public void testNotMeetPredicateRequirement() throws Exception { + String sql; + String plan; + + // Predicate direction wrong. + sql = "select c1 from test_cosine " + + "where approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1) <= 100 " + + "order by approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1) desc limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, "VECTORINDEX: OFF"); + + sql = "select c1 from test_l2 " + + "where approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) >= 100 " + + "order by approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, "VECTORINDEX: OFF"); + + // Must >=, <=, not >, <. + sql = "select c1 from test_cosine " + + "where approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1) > 100 " + + "order by approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1) desc limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, "VECTORINDEX: OFF"); + + sql = "select c1 from test_l2 " + + "where approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) < 100 " + + "order by approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, "VECTORINDEX: OFF"); + + // Column ref is not vector column. + sql = "select c1 from test_l2 " + + "where approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c2) <= 100 " + + "order by approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, "VECTORINDEX: OFF"); + + // constant vector is not the same. + sql = "select c1 from test_l2 " + + "where approx_l2_distance([10,2.2,3.3], c2) <= 100 " + + "order by approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, "VECTORINDEX: OFF"); + + // Cannot deal with approx_l2_distance with other functions. 
+ sql = "select c1 from test_l2 " + + "where approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) * 2 <= 100 " + + "order by approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, "VECTORINDEX: OFF"); + + // Cannot deal with approx_l2_distance with other predicates. + sql = "select c1 from test_l2 " + + "where approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) <= 100 and c0 < 10 " + + "order by approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, "VECTORINDEX: OFF"); + + // OR + sql = "select c1 from test_l2 " + + "where approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) <= 100 or approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) <= 1000 " + + "order by approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, "VECTORINDEX: OFF"); + } + + @Test + public void testRewrite() throws Exception { + String sql; + String plan; + + sql = "select c1, " + + "approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1)+1, " + + "approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1)+2, " + + "cast(approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1) as string), " + + "approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c2)+2 " + + "from test_cosine " + + "order by approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1) desc limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, " 2:TOP-N\n" + + " | order by: [9, FLOAT, false] DESC\n" + + " | build runtime filters:\n" + + " | - filter_id = 0, build_expr = ( 9: approx_cosine_similarity), remote = false\n" + + " | offset: 0\n" + + " | limit: 10\n" + + " | cardinality: 1\n" + + " | \n" + + " 1:Project\n" + + " | output columns:\n" + + " | 2 <-> [2: c1, ARRAY, false]\n" + + " | 5 <-> [13: cast, DOUBLE, true] + 1.0\n" + + " | 6 <-> [13: cast, DOUBLE, true] + 2.0\n" + + " | 7 <-> cast([12: __vector_approx_cosine_similarity, FLOAT, false] as VARCHAR(65533))\n" + + " | 8 <-> 
cast(approx_cosine_similarity[(cast([1.1,2.2,3.3,4.4,5.5] as ARRAY), [3: c2, ARRAY, true]); " + + "args: INVALID_TYPE,INVALID_TYPE; result: FLOAT; args nullable: true; result nullable: true] as DOUBLE) + 2.0\n" + + " | 9 <-> [12: __vector_approx_cosine_similarity, FLOAT, false]\n" + + " | common expressions:\n" + + " | 13 <-> cast([12: __vector_approx_cosine_similarity, FLOAT, false] as DOUBLE)\n" + + " | cardinality: 1\n" + + " | \n" + + " 0:OlapScanNode\n" + + " table: test_cosine, rollup: test_cosine\n" + + " VECTORINDEX: ON\n" + + " IVFPQ: OFF, Distance Column: <12:__vector_approx_cosine_similarity>, LimitK: 10, Order: DESC, " + + "Query Vector: [1.1, 2.2, 3.3, 4.4, 5.5], Predicate Range: -1.0"); + } + + @Test + public void testArgumentOrder() throws Exception { + String sql; + String plan; + + // Vector function argument order doesn't matter. + sql = "select c1 from test_cosine " + + "order by approx_cosine_similarity(c1, [1.1,2.2,3.3,4.4,5.5]) desc limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, " VECTORINDEX: ON\n" + + " IVFPQ: OFF, Distance Column: <7:__vector_approx_cosine_similarity>, LimitK: 10, Order: DESC, " + + "Query Vector: [1.1, 2.2, 3.3, 4.4, 5.5], Predicate Range: -1.0"); + + sql = "select c1 from test_l2 " + + "order by approx_l2_distance(c1, [1.1,2.2,3.3,4.4,5.5]) limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, " VECTORINDEX: ON\n" + + " IVFPQ: OFF, Distance Column: <6:__vector_approx_l2_distance>, LimitK: 10, Order: ASC, " + + "Query Vector: [1.1, 2.2, 3.3, 4.4, 5.5], Predicate Range: -1.0"); + + // Predicate argument order doesn't matter. 
+ sql = "select c1 from test_cosine " + + "where 100 <= approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1) " + + "order by approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1) desc limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, " VECTORINDEX: ON\n" + + " IVFPQ: OFF, Distance Column: <7:__vector_approx_cosine_similarity>, LimitK: 10, Order: DESC, " + + "Query Vector: [1.1, 2.2, 3.3, 4.4, 5.5], Predicate Range: 100.0"); + + sql = "select c1 from test_l2 " + + "where 100 >= approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) " + + "order by approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) limit 10"; + plan = getVerboseExplain(sql); + assertContains(plan, " VECTORINDEX: ON\n" + + " IVFPQ: OFF, Distance Column: <6:__vector_approx_l2_distance>, LimitK: 10, Order: ASC, " + + "Query Vector: [1.1, 2.2, 3.3, 4.4, 5.5], Predicate Range: 100.0"); + } + + @Test + public void testMultipleTables() throws Exception { + String sql; + String plan; + + sql = "(select c1 from test_cosine " + + "where approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1) >= 100 " + + "order by approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1) desc limit 10) " + + "UNION ALL " + + "(select c1 from test_l2 " + + "where approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) <= 100 " + + "order by approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) limit 10) " + + "UNION ALL " + + "(select c1 from test_cosine " + + "where approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1) >= 100 " + + "order by approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1) limit 10) " + + "UNION ALL " + + "(select c1 from test_l2 " + + "where approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) <= 100 " + + "order by approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) DESC limit 10) " + + "UNION ALL " + + "(select c1 from test_no_vector_index)"; + plan = getVerboseExplain(sql); + System.out.println(plan); + assertContains(plan, " 1:OlapScanNode\n" + + " table: test_cosine, rollup: test_cosine\n" + + " VECTORINDEX: ON\n" + + " IVFPQ: OFF, Distance 
Column: <24:__vector_approx_cosine_similarity>, LimitK: 10, Order: DESC, " + + "Query Vector: [1.1, 2.2, 3.3, 4.4, 5.5], Predicate Range: 100.0"); + assertContains(plan, " 7:OlapScanNode\n" + + " table: test_l2, rollup: test_l2\n" + + " VECTORINDEX: ON\n" + + " IVFPQ: OFF, Distance Column: <23:__vector_approx_l2_distance>, LimitK: 10, Order: ASC, " + + "Query Vector: [1.1, 2.2, 3.3, 4.4, 5.5], Predicate Range: 100.0"); + assertContains(plan, " 13:OlapScanNode\n" + + " table: test_cosine, rollup: test_cosine\n" + + " VECTORINDEX: OFF"); + assertContains(plan, " 25:OlapScanNode\n" + + " table: test_no_vector_index, rollup: test_no_vector_index\n" + + " VECTORINDEX: OFF"); + } + + @Test + public void testQueryVectorDimNotMatch() throws Exception { + String sql = "select c1 from test.test_cosine " + + "order by approx_cosine_similarity([1.1,2.2,3.3,4.4], c1) desc limit 10"; + assertThatThrownBy(() -> getVerboseExplain(sql)) + .isInstanceOf(SemanticException.class) + .hasMessageContaining( + "The vector query size ([1.1, 2.2, 3.3, 4.4]) is not equal to the vector index dimension (5)"); + } + + @Test + public void testIvfpq() throws Exception { + String sql = "select c1, approx_l2_distance([1.1,2.2,3.3,4.4], c1) as score" + + " from test_ivfpq order by score limit 10"; + String plan = getVerboseExplain(sql); + assertContains(plan, " 2:TOP-N\n" + + " | order by: [4, FLOAT, true] ASC\n" + + " | build runtime filters:\n" + + " | - filter_id = 0, build_expr = ( 4: approx_l2_distance), remote = false\n" + + " | offset: 0\n" + + " | limit: 10\n" + + " | cardinality: 1\n" + + " | \n" + + " 1:Project\n" + + " | output columns:\n" + + " | 2 <-> [2: c1, ARRAY, false]\n" + + " | 4 <-> approx_l2_distance[(cast([1.1,2.2,3.3,4.4] as ARRAY), [2: c1, ARRAY, false]); args: INVALID_TYPE,INVALID_TYPE; result: FLOAT; args nullable: true; result nullable: true]\n" + + " | cardinality: 1\n" + + " | \n" + + " 0:OlapScanNode\n" + + " table: test_ivfpq, rollup: test_ivfpq\n" + + " VECTORINDEX: 
ON\n" + + " IVFPQ: ON, Distance Column: <0:__vector_approx_l2_distance>, LimitK: 10, Order: ASC, Query Vector: [1.1, 2.2, 3.3, 4.4], Predicate Range: -1.0"); + } + @Test public void testVectorIndexSyntax() throws Exception { String sql1 = "select c1 from test.test_cosine " + - "order by approx_cosine_similarity([1.1,2.2,3.3], c1) desc limit 10"; + "order by approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1) desc limit 10"; assertPlanContains(sql1, "VECTORINDEX: ON"); String sql2 = "select c1 from test.test_l2 " + - "order by approx_l2_distance([1.1,2.2,3.3], c1) limit 10"; + "order by approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) limit 10"; assertPlanContains(sql2, "VECTORINDEX: ON"); // Sorting in desc order doesn't make sense in l2_distance, // which won't trigger the vector retrieval logic. String sql3 = "select c1 from test.test_l2 " + - "order by approx_l2_distance([1.1,2.2,3.3], c1) desc limit 10"; + "order by approx_l2_distance([1.1,2.2,3.3,4.4,5.5], c1) desc limit 10"; assertPlanContains(sql3, "VECTORINDEX: OFF"); String sql4 = "select c1 from test.test_cosine " + - "order by cosine_similarity([1.1,2.2,3.3], c1) desc limit 10"; + "order by cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1) desc limit 10"; assertPlanContains(sql4, "VECTORINDEX: OFF"); - String sql5 = "select c1, approx_l2_distance([1.1,2.2,3.3], c1) as score" + String sql5 = "select c1, approx_l2_distance([1.1,2.2,3.3,4.4], c1) as score" + " from test.test_ivfpq order by score limit 10"; assertPlanContains(sql5, "VECTORINDEX: ON"); - String sql6 = "select c1, approx_cosine_similarity([1.1,2.2,3.3], c1) as score" + String sql6 = "select c1, approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1) as score" + " from test.test_cosine order by score desc limit 10"; assertPlanContains(sql6, "VECTORINDEX: ON"); - String sql7 = "select c1, approx_cosine_similarity([1.1,2.2,3.3], c1) as score" + String sql7 = "select c1, approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1) as score" + " from test.test_cosine 
where c0 = 1 order by score desc limit 10"; assertPlanContains(sql7, "VECTORINDEX: OFF"); - String sql8 = "select c1, approx_cosine_similarity([1.1,2.2,3.3], c1) as score" - + " from test.test_cosine having score > 0.8 order by score desc limit 10"; + String sql8 = "select c1, approx_cosine_similarity([1.1,2.2,3.3,4.4,5.5], c1) as score" + + " from test.test_cosine having score >= cast(0.8 as float) order by score desc limit 10"; assertPlanContains(sql8, "VECTORINDEX: ON"); - - String sql9 = "select c1, approx_cosine_similarity([1.1,2.2,3.3], c1) as score" - + " from test.test_cosine having score < 0.8 order by score desc limit 10"; - assertPlanContains(sql9, "VECTORINDEX: OFF"); } + } diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/analyzer/AnalyzeAlterTableStatementTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/analyzer/AnalyzeAlterTableStatementTest.java index 0162e6273151d..81b1ed9167a31 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/analyzer/AnalyzeAlterTableStatementTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/analyzer/AnalyzeAlterTableStatementTest.java @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
- package com.starrocks.sql.analyzer; import com.google.common.collect.Lists; @@ -76,7 +75,7 @@ public void testNoClause() { } @Test(expected = SemanticException.class) - public void testCompactionClause() { + public void testCompactionClause() { new MockUp() { @Mock public RunMode getCurrentRunMode() { diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/analyzer/AnalyzeSetVariableTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/analyzer/AnalyzeSetVariableTest.java index 6a33cf9430298..cb60ce2971990 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/analyzer/AnalyzeSetVariableTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/analyzer/AnalyzeSetVariableTest.java @@ -18,6 +18,7 @@ import com.starrocks.analysis.Subquery; import com.starrocks.catalog.ResourceGroupMgr; import com.starrocks.qe.ConnectContext; +import com.starrocks.qe.SessionVariable; import com.starrocks.qe.SetExecutor; import com.starrocks.server.GlobalStateMgr; import com.starrocks.sql.ast.SetPassVar; @@ -27,6 +28,7 @@ import com.starrocks.thrift.TWorkGroup; import com.starrocks.utframe.StarRocksAssert; import com.starrocks.utframe.UtFrameUtils; +import com.uber.m3.util.ImmutableMap; import mockit.Expectations; import org.junit.Assert; import org.junit.BeforeClass; @@ -35,9 +37,12 @@ import static com.starrocks.sql.analyzer.AnalyzeTestUtil.analyzeFail; import static com.starrocks.sql.analyzer.AnalyzeTestUtil.analyzeSetUserVariableFail; import static com.starrocks.sql.analyzer.AnalyzeTestUtil.analyzeSuccess; +import static com.starrocks.sql.analyzer.AnalyzeTestUtil.connectContext; +import static org.assertj.core.api.Assertions.assertThat; public class AnalyzeSetVariableTest { private static StarRocksAssert starRocksAssert; + @BeforeClass public static void beforeClass() throws Exception { UtFrameUtils.createMinStarRocksCluster(); @@ -212,6 +217,7 @@ public void testSetResourceGroupName() { mgr.chooseResourceGroupByName(rg1Name); result = rg1; } + { 
mgr.chooseResourceGroupByName(anyString); result = null; @@ -241,6 +247,7 @@ public void testSetResourceGroupID() { mgr.chooseResourceGroupByID(rg1ID); result = rg1; } + { mgr.chooseResourceGroupByID(anyLong); result = null; @@ -310,4 +317,41 @@ public void testComputationFragmentSchedulingPolicy() { sql = "SET computation_fragment_scheduling_policy = compute_nodes"; analyzeFail(sql); } + + @Test + public void testSetAnnParams() { + SessionVariable sv = connectContext.getSessionVariable(); + String sql; + + sql = "set ann_params='invalid-format'"; + analyzeFail(sql, + "Unsupported ann_params: invalid-format, " + + "It should be a Dict JSON string, each key and value of which is string"); + + sql = "set ann_params='{\"Efsearch\": [1,2,3]}'"; + analyzeFail(sql, + "Unsupported ann_params: {\"Efsearch\": [1,2,3]}, " + + "It should be a Dict JSON string, each key and value of which is string"); + + sql = "set ann_params='{\"invalid-key\":\"abc\"}'"; + analyzeFail(sql, "Unknown index param: `INVALID-KEY"); + + sql = "set ann_params='{\"Efsearch\": 0}'"; + analyzeFail(sql, "Value of `EFSEARCH` must be >= 1"); + + sql = "set ann_params='{}'"; + analyzeSuccess(sql); + sv.setAnnParams("{}"); + assertThat(connectContext.getSessionVariable().getAnnParams()).isEmpty(); + + sql = "set ann_params=''"; + analyzeSuccess(sql); + sv.setAnnParams(""); + assertThat(connectContext.getSessionVariable().getAnnParams()).isEmpty(); + + sql = "set ann_params='{\"Efsearch\": 1}'"; + analyzeSuccess(sql); + sv.setAnnParams("{\"Efsearch\": 1}"); + assertThat(connectContext.getSessionVariable().getAnnParams()).containsExactlyEntriesOf(ImmutableMap.of("Efsearch", "1")); + } } diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/analyzer/AnalyzeVectorIndexDMLTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/analyzer/AnalyzeVectorIndexDMLTest.java new file mode 100644 index 0000000000000..42fe06aaab5e5 --- /dev/null +++ 
b/fe/fe-core/src/test/java/com/starrocks/sql/analyzer/AnalyzeVectorIndexDMLTest.java @@ -0,0 +1,625 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.starrocks.sql.analyzer; + +import com.starrocks.common.Config; +import com.starrocks.qe.ConnectContext; +import com.starrocks.qe.QueryState; +import com.starrocks.qe.StmtExecutor; +import com.starrocks.sql.ast.StatementBase; +import com.starrocks.sql.parser.SqlParser; +import com.starrocks.utframe.UtFrameUtils; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.util.List; + +import static com.starrocks.sql.analyzer.AnalyzeTestUtil.analyzeFail; +import static com.starrocks.sql.analyzer.AnalyzeTestUtil.analyzeSuccess; +import static org.assertj.core.api.Assertions.assertThat; + +public class AnalyzeVectorIndexDMLTest { + private static ConnectContext connectContext; + + @BeforeClass + public static void beforeClass() throws Exception { + UtFrameUtils.createMinStarRocksCluster(); + AnalyzeTestUtil.init(); + UtFrameUtils.addMockBackend(10002); + UtFrameUtils.addMockBackend(10003); + connectContext = AnalyzeTestUtil.getConnectContext(); + + Config.enable_experimental_vector = true; + } + + @Test + public void testValidateParamsForCreateTable() { + String sql; + + sql = "CREATE TABLE vector_t1 (\n" + + " id bigint(20) NOT NULL,\n" + + " vector ARRAY NOT NULL,\n" + + " INDEX index_vector (vector) USING VECTOR (\n" + + 
" 'index_type' = 'IVFPQ', \n" + + " 'dim'='4', \n" + + " 'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'Nbits' = '8',\n" + + " 'Nlist' = '16', \n" + + " 'M_IVFPQ' = '2'\n" + + " )\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1"; + analyzeSuccess(sql); + + sql = "CREATE TABLE vector_t1 (\n" + + " id bigint(20) NOT NULL,\n" + + " vector ARRAY NOT NULL,\n" + + " INDEX index_vector (vector) USING VECTOR (\n" + + " 'index_type' = 'HNSW', \n" + + " 'dim'='5', \n" + + " 'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'M' = '2', \n" + + " 'efconstruction' = '1'\n" + + " )\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1"; + analyzeSuccess(sql); + + sql = "CREATE TABLE vector_t1 (\n" + + " id bigint(20) NOT NULL,\n" + + " vector ARRAY NOT NULL,\n" + + " INDEX index_vector (vector) USING VECTOR (\n" + + " 'index_type' = 'invalid-index-type', \n" + + " 'dim'='5', \n" + + " 'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'M' = '16', \n" + + " 'efconstruction' = '40'\n" + + " )\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1"; + analyzeFail(sql, "Value of `index_type` must in (IVFPQ,HNSW)"); + + sql = "CREATE TABLE vector_t1 (\n" + + " id bigint(20) NOT NULL,\n" + + " vector ARRAY NOT NULL,\n" + + " INDEX index_vector (vector) USING VECTOR (\n" + + " 'index_type' = 'HNSW', \n" + + " 'dim'='invalid-dim', \n" + + " 'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'M' = '16', \n" + + " 'efconstruction' = '40'\n" + + " )\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1"; + analyzeFail(sql, "Value of `DIM` must be a integer"); + sql = "CREATE TABLE vector_t1 (\n" + + " id bigint(20) NOT NULL,\n" + + " vector ARRAY NOT NULL,\n" + + " INDEX index_vector (vector) USING VECTOR (\n" + + " 
'index_type' = 'HNSW', \n" + + " 'dim'='0', \n" + + " 'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'M' = '16', \n" + + " 'efconstruction' = '40'\n" + + " )\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1"; + analyzeFail(sql, "Value of `DIM` must be >= 1"); + + sql = "CREATE TABLE vector_t1 (\n" + + " id bigint(20) NOT NULL,\n" + + " vector ARRAY NOT NULL,\n" + + " INDEX index_vector (vector) USING VECTOR (\n" + + " 'index_type' = 'HNSW', \n" + + " 'dim'='5', \n" + + " 'metric_type' = 'invalid-metric-type', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'M' = '16', \n" + + " 'efconstruction' = '40'\n" + + " )\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1"; + analyzeFail(sql, "Value of `METRIC_TYPE` must be in [l2_distance, cosine_similarity]"); + + sql = "CREATE TABLE vector_t1 (\n" + + " id bigint(20) NOT NULL,\n" + + " vector ARRAY NOT NULL,\n" + + " INDEX index_vector (vector) USING VECTOR (\n" + + " 'index_type' = 'HNSW', \n" + + " 'dim'='5', \n" + + " 'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'invalid-is-vector-normed', \n" + + " 'M' = '16', \n" + + " 'efconstruction' = '40'\n" + + " )\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1"; + analyzeFail(sql, "Value of `IS_VECTOR_NORMED` must be `true` or `false`"); + + sql = "CREATE TABLE vector_t1 (\n" + + " id bigint(20) NOT NULL,\n" + + " vector ARRAY NOT NULL,\n" + + " INDEX index_vector (vector) USING VECTOR (\n" + + " 'index_type' = 'HNSW', \n" + + " 'dim'='5', \n" + + " 'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'M' = 'invalid-M', \n" + + " 'efconstruction' = '40'\n" + + " )\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1"; + analyzeFail(sql, "Value of `M` must be a integer"); + + sql = "CREATE TABLE vector_t1 (\n" + + " id bigint(20) NOT NULL,\n" 
+ + " vector ARRAY NOT NULL,\n" + + " INDEX index_vector (vector) USING VECTOR (\n" + + " 'index_type' = 'HNSW', \n" + + " 'dim'='5', \n" + + " 'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'M' = '1', \n" + + " 'efconstruction' = '40'\n" + + " )\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1"; + analyzeFail(sql, "Value of `M` must be >= 2"); + + sql = "CREATE TABLE vector_t1 (\n" + + " id bigint(20) NOT NULL,\n" + + " vector ARRAY NOT NULL,\n" + + " INDEX index_vector (vector) USING VECTOR (\n" + + " 'index_type' = 'HNSW', \n" + + " 'dim'='5', \n" + + " 'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'M' = '1', \n" + + " 'efconstruction' = 'invalid-efconstruction'\n" + + " )\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1"; + analyzeFail(sql, "Value of `EFCONSTRUCTION` must be a integer"); + + sql = "CREATE TABLE vector_t1 (\n" + + " id bigint(20) NOT NULL,\n" + + " vector ARRAY NOT NULL,\n" + + " INDEX index_vector (vector) USING VECTOR (\n" + + " 'index_type' = 'HNSW', \n" + + " 'dim'='5', \n" + + " 'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'M' = '1', \n" + + " 'efconstruction' = '0'\n" + + " )\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1"; + analyzeFail(sql, "Value of `EFCONSTRUCTION` must be >= 1"); + + sql = "CREATE TABLE vector_t1 (\n" + + " id bigint(20) NOT NULL,\n" + + " vector ARRAY NOT NULL,\n" + + " INDEX index_vector (vector) USING VECTOR (\n" + + " 'index_type' = 'HNSW', \n" + + " 'dim'='5', \n" + + " 'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'M' = '2', \n" + + " 'efconstruction' = '1',\n" + + " 'Nbits' = '8'\n" + + " )\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1"; + analyzeFail(sql, "Index params [NBITS] should not define with HNSW"); 
+ + sql = "CREATE TABLE vector_t1 (\n" + + " id bigint(20) NOT NULL,\n" + + " vector ARRAY NOT NULL,\n" + + " INDEX index_vector (vector) USING VECTOR (\n" + + " 'index_type' = 'IVFPQ', \n" + + " 'dim'='4', \n" + + " 'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'Nbits' = 'invalid-Nbits',\n" + + " 'Nlist' = '16', \n" + + " 'M_IVFPQ' = '2'\n" + + " )\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1"; + analyzeFail(sql, "Value of `NBITS` must be a integer"); + + sql = "CREATE TABLE vector_t1 (\n" + + " id bigint(20) NOT NULL,\n" + + " vector ARRAY NOT NULL,\n" + + " INDEX index_vector (vector) USING VECTOR (\n" + + " 'index_type' = 'IVFPQ', \n" + + " 'dim'='4', \n" + + " 'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'Nbits' = '2',\n" + + " 'Nlist' = '16', \n" + + " 'M_IVFPQ' = '2'\n" + + " )\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1"; + analyzeFail(sql, "Value of `NBITS` must be 8"); + + sql = "CREATE TABLE vector_t1 (\n" + + " id bigint(20) NOT NULL,\n" + + " vector ARRAY NOT NULL,\n" + + " INDEX index_vector (vector) USING VECTOR (\n" + + " 'index_type' = 'IVFPQ', \n" + + " 'dim'='4', \n" + + " 'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'Nbits' = '8',\n" + + " 'Nlist' = 'invalid-Nlist', \n" + + " 'M_IVFPQ' = '2'\n" + + " )\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1"; + analyzeFail(sql, "Value of `NLIST` must be a integer"); + + sql = "CREATE TABLE vector_t1 (\n" + + " id bigint(20) NOT NULL,\n" + + " vector ARRAY NOT NULL,\n" + + " INDEX index_vector (vector) USING VECTOR (\n" + + " 'index_type' = 'IVFPQ', \n" + + " 'dim'='4', \n" + + " 'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'Nbits' = '8',\n" + + " 'Nlist' = '0',\n" + + " 'M_IVFPQ' = '2'\n" + + " )\n" + + ") ENGINE=OLAP\n" + + 
"DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1"; + analyzeFail(sql, "Value of `NLIST` must be >= 1"); + + sql = "CREATE TABLE vector_t1 (\n" + + " id bigint(20) NOT NULL,\n" + + " vector ARRAY NOT NULL,\n" + + " INDEX index_vector (vector) USING VECTOR (\n" + + " 'index_type' = 'IVFPQ', \n" + + " 'dim'='4', \n" + + " 'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'Nbits' = '8',\n" + + " 'Nlist' = '16', 'invalid-key'='10', \n" + + " 'M_IVFPQ' = '2'\n" + + " )\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1"; + analyzeFail(sql, "Unknown index param: `INVALID-KEY`"); + } + + @Test + public void testValidateParamsForAlterTable() throws Exception { + AnalyzeTestUtil.getStarRocksAssert().withTable("CREATE TABLE vector_t1 (\n" + + " id bigint(20) NOT NULL,\n" + + " v1 ARRAY NOT NULL\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1\n" + + "PROPERTIES ('replication_num'='1');"); + String sql; + + try { + sql = "ALTER TABLE vector_t1 ADD INDEX index_vector1 (v1) USING VECTOR (\n" + + " 'index_type' = 'IVFPQ', \n" + + " 'dim'='4', \n" + + " 'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'Nbits' = '8',\n" + + " 'Nlist' = '16', \n" + + " 'M_IVFPQ' = '2'\n" + + " )\n"; + analyzeSuccess(sql); + + sql = "ALTER TABLE vector_t1 ADD INDEX index_vector1 (v1) USING VECTOR (\n" + + " 'index_type' = 'HNSW', \n" + + " 'dim'='5', \n" + + " 'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'M' = '2', \n" + + " 'efconstruction' = '1'\n" + + " )\n"; + analyzeSuccess(sql); + + sql = "ALTER TABLE vector_t1 ADD INDEX index_vector1 (v1) USING VECTOR (\n" + + " 'index_type' = 'aIVFPQ', \n" + + " 'dim'='5', \n" + + " 'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'Nbits' = '8',\n" + + " 'Nlist' = '16' \n" + + " )\n"; + analyzeSuccess(sql); + + StatementBase statement = 
SqlParser.parseSingleStatement(sql, connectContext.getSessionVariable().getSqlMode()); + StmtExecutor stmtExecutor = new StmtExecutor(connectContext, statement); + stmtExecutor.execute(); + assertThat(connectContext.getState().getErrType()).isEqualTo(QueryState.ErrType.INTERNAL_ERR); + assertThat(connectContext.getState().getErrorMessage()).contains("Value of `index_type` must in (IVFPQ,HNSW)"); + + } finally { + AnalyzeTestUtil.getStarRocksAssert().dropTables(List.of("vector_t1")); + } + } + + @Test + public void testOnlyOneVectorIndexForCreateTable() { + String sql; + + sql = "CREATE TABLE vector_t1 (\n" + + " id bigint(20) NOT NULL,\n" + + " v1 ARRAY NOT NULL,\n" + + " v2 ARRAY NOT NULL,\n" + + " INDEX index_v1 (vector) USING VECTOR (\n" + + " 'index_type' = 'IVFPQ', \n" + + " 'dim'='4', \n" + + " 'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'Nbits' = '8',\n" + + " 'Nlist' = '16', \n" + + " 'M_IVFPQ' = '2'\n" + + " ),\n" + + " INDEX index_v2 (vector) USING VECTOR (\n" + + " 'index_type' = 'IVFPQ', \n" + + " 'dim'='4', \n" + + " 'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'Nbits' = '8',\n" + + " 'Nlist' = '16', \n" + + " 'M_IVFPQ' = '2'\n" + + " )\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1"; + analyzeFail(sql, "At most one vector index is allowed for a table, but 2 were found: [index_v1, index_v2]"); + } + + @Test + public void testOnlyOneVectorIndexForAlterTableSuccess() throws Exception { + AnalyzeTestUtil.getStarRocksAssert().withTable("CREATE TABLE vector_t1 (\n" + + " id bigint(20) NOT NULL,\n" + + " v1 ARRAY NOT NULL,\n" + + " v2 ARRAY NOT NULL\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1\n" + + "PROPERTIES ('replication_num'='1');"); + String sql; + + try { + sql = "ALTER TABLE vector_t1 ADD INDEX index_v1 (v1) USING VECTOR (\n" + + " 'index_type' = 'IVFPQ', \n" + + " 'dim'='4', \n" + + " 
'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'Nbits' = '8',\n" + + " 'Nlist' = '16', \n" + + " 'M_IVFPQ' = '2'\n" + + " )\n"; + StatementBase statement = SqlParser.parseSingleStatement(sql, connectContext.getSessionVariable().getSqlMode()); + StmtExecutor stmtExecutor = new StmtExecutor(connectContext, statement); + stmtExecutor.execute(); + assertThat(connectContext.getState().isError()).isFalse(); + } finally { + AnalyzeTestUtil.getStarRocksAssert().dropTables(List.of("vector_t1")); + } + } + + @Test + public void testOnlyOneVectorIndexForAlterTableFail() throws Exception { + AnalyzeTestUtil.getStarRocksAssert().withTable("CREATE TABLE vector_t1 (\n" + + " id bigint(20) NOT NULL,\n" + + " v1 ARRAY NOT NULL,\n" + + " v2 ARRAY NOT NULL,\n" + + " INDEX index_v1 (v1) USING VECTOR (\n" + + " 'index_type' = 'IVFPQ', \n" + + " 'dim'='4', \n" + + " 'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'Nbits' = '8',\n" + + " 'Nlist' = '16',\n" + + " 'M_IVFPQ' = '2'\n" + + " )\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1\n" + + "PROPERTIES ('replication_num'='1');"); + String sql; + + try { + { + sql = "ALTER TABLE vector_t1 ADD INDEX index_v2 (v2) USING VECTOR (\n" + + " 'index_type' = 'IVFPQ', \n" + + " 'dim'='4', \n" + + " 'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'Nbits' = '8',\n" + + " 'Nlist' = '16', \n" + + " 'M_IVFPQ' = '2'\n" + + " )\n"; + StatementBase statement = SqlParser.parseSingleStatement(sql, connectContext.getSessionVariable().getSqlMode()); + StmtExecutor stmtExecutor = new StmtExecutor(connectContext, statement); + stmtExecutor.execute(); + assertThat(connectContext.getState().getErrorMessage()).contains( + "At most one vector index is allowed for a table, but there is already a vector index [index_v1]"); + } + } finally { + AnalyzeTestUtil.getStarRocksAssert().dropTables(List.of("vector_t1")); + } + } + + 
@Test + public void testCreateOnNullableColumn() { + String sql; + + sql = "CREATE TABLE vector_t1 (\n" + + " id bigint(20) NOT NULL,\n" + + " vector ARRAY NULL,\n" + + " INDEX index_vector (vector) USING VECTOR (\n" + + " 'index_type' = 'IVFPQ', \n" + + " 'dim'='5', \n" + + " 'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'Nbits' = '8',\n" + + " 'Nlist' = '16', \n" + + " 'M_IVFPQ' = '2'\n" + + " )\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1"; + analyzeFail(sql, "The vector index can only build on non-nullable column"); + } + + @Test + public void testIVFPQ() { + String sql; + + sql = "CREATE TABLE vector_t1 (\n" + + " id bigint(20) NOT NULL,\n" + + " v1 ARRAY NOT NULL,\n" + + " v2 ARRAY NOT NULL,\n" + + " INDEX index_v1 (v1) USING VECTOR (\n" + + " 'index_type' = 'IVFPQ', \n" + + " 'dim'='4', \n" + + " 'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'Nbits' = '8',\n" + + " 'Nlist' = '16'" + + " )\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1\n" + + "PROPERTIES ('replication_num'='1');"; + analyzeFail(sql, "`M_IVFPQ` is required for IVFPQ index"); + + sql = "CREATE TABLE vector_t1 (\n" + + " id bigint(20) NOT NULL,\n" + + " v1 ARRAY NOT NULL,\n" + + " v2 ARRAY NOT NULL,\n" + + " INDEX index_v1 (v1) USING VECTOR (\n" + + " 'index_type' = 'IVFPQ', \n" + + " 'dim'='10', \n" + + " 'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'Nbits' = '8',\n" + + " 'Nlist' = '16', \n" + + " 'M_IVFPQ' = '3' \n" + + " )\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1\n" + + "PROPERTIES ('replication_num'='1');"; + analyzeFail(sql, "`DIM` should be a multiple of `M_IVFPQ` for IVFPQ index"); + + sql = "CREATE TABLE vector_t1 (\n" + + " id bigint(20) NOT NULL,\n" + + " v1 ARRAY NOT NULL,\n" + + " v2 ARRAY NOT NULL,\n" + + " INDEX index_v1 (v1) USING VECTOR 
(\n" + + " 'index_type' = 'IVFPQ', \n" + + " 'dim'='10', \n" + + " 'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'Nbits' = '8',\n" + + " 'Nlist' = '16', \n" + + " 'M_IVFPQ' = '2' \n" + + " )\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1\n" + + "PROPERTIES ('replication_num'='1');"; + analyzeSuccess(sql); + } + + @Test + public void testShowIndex() throws Exception { + String sql; + String show; + + AnalyzeTestUtil.getStarRocksAssert().withTable("CREATE TABLE vector_t1 (\n" + + " id bigint(20) NOT NULL,\n" + + " v1 ARRAY NOT NULL,\n" + + " v2 ARRAY NOT NULL,\n" + + " INDEX index_v1 (v1) USING VECTOR (\n" + + " 'index_type' = 'IVFPQ', \n" + + " 'dim'='10', \n" + + " 'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'Nbits' = '8',\n" + + " 'Nlist' = '16',\n" + + " 'M_IVFPQ' = '2'\n" + + " )\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1\n" + + "PROPERTIES ('replication_num'='1');"); + show = AnalyzeTestUtil.getStarRocksAssert().showCreateTable("show create table vector_t1"); + assertThat(show).contains( + "INDEX index_v1 (`v1`) USING VECTOR(\"dim\" = \"10\", \"index_type\" = \"ivfpq\", " + + "\"is_vector_normed\" = \"false\", \"m_ivfpq\" = \"2\", \"metric_type\" = \"l2_distance\", " + + "\"nbits\" = \"8\", \"nlist\" = \"16\")"); + + sql = "CREATE TABLE vector_t2 (\n" + + " id bigint(20) NOT NULL,\n" + + " vector ARRAY NOT NULL,\n" + + " INDEX index_vector (vector) USING VECTOR (\n" + + " 'index_type' = 'HNSW', \n" + + " 'dim'='5', \n" + + " 'metric_type' = 'l2_distance', \n" + + " 'is_vector_normed' = 'false', \n" + + " 'M' = '2', \n" + + " 'efconstruction' = '1'\n" + + " )\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1\n" + + "PROPERTIES ('replication_num'='1');"; + AnalyzeTestUtil.getStarRocksAssert().withTable(sql); + show = 
AnalyzeTestUtil.getStarRocksAssert().showCreateTable("show create table vector_t2"); + assertThat(show).contains( + "INDEX index_vector (`vector`) USING VECTOR(\"dim\" = \"5\", \"efconstruction\" = \"1\", " + + "\"index_type\" = \"hnsw\", \"is_vector_normed\" = \"false\", \"m\" = \"2\", " + + "\"metric_type\" = \"l2_distance\")"); + + AnalyzeTestUtil.getStarRocksAssert().dropTables(List.of("vector_t1", "vector_t2")); + } +} diff --git a/gensrc/thrift/PlanNodes.thrift b/gensrc/thrift/PlanNodes.thrift index 17d90fe4e48dc..934316b85993c 100644 --- a/gensrc/thrift/PlanNodes.thrift +++ b/gensrc/thrift/PlanNodes.thrift @@ -570,6 +570,7 @@ struct TVectorSearchOptions { 8: optional bool use_ivfpq; 9: optional double pq_refine_factor; 10: optional double k_factor; + 11: optional i32 vector_slot_id; } enum SampleMethod { diff --git a/test/sql/test_vector_index/R/test_vector_index b/test/sql/test_vector_index/R/test_vector_index index 09c176a587166..1b8b6b8e7cc77 100644 --- a/test/sql/test_vector_index/R/test_vector_index +++ b/test/sql/test_vector_index/R/test_vector_index @@ -1,4 +1,4 @@ --- name: test_create_vector_index +-- name: test_create_vector_index @sequential ADMIN SET FRONTEND CONFIG("enable_experimental_vector" = "true"); -- result: -- !result @@ -18,7 +18,14 @@ PROPERTIES ( ); -- result: -- !result -CREATE INDEX index_vector2 ON t_test_vector_table (vector2) USING VECTOR ("metric_type" = "l2_distance", "is_vector_normed" = "false", "index_type" = "ivfpq", "dim"="5", "nlist" = "256", "nbits"="10"); +DROP INDEX index_vector1 ON t_test_vector_table; +-- result: +-- !result +function: wait_alter_table_finish() +-- result: +None +-- !result +CREATE INDEX index_vector2 ON t_test_vector_table (vector2) USING VECTOR ("metric_type" = "l2_distance", "is_vector_normed" = "false", "index_type" = "ivfpq", "dim"="4", "nlist" = "256", "nbits"="8", "M_IVFPQ"="2"); -- result: -- !result function: wait_alter_table_finish() @@ -32,7 +39,7 @@ function: wait_alter_table_finish() -- 
result: None -- !result -ALTER TABLE t_test_vector_table add index index_vector2 (vector2) USING VECTOR ("metric_type" = "l2_distance", "is_vector_normed" = "false", "index_type" = "ivfpq", "dim"="5", "nlist" = "256", "nbits"="10"); +ALTER TABLE t_test_vector_table add index index_vector2 (vector2) USING VECTOR ("metric_type" = "l2_distance", "is_vector_normed" = "false", "index_type" = "ivfpq", "dim"="4", "nlist" = "256", "nbits"="8", "M_IVFPQ"="2"); -- result: -- !result function: wait_alter_table_finish() @@ -50,8 +57,14 @@ None DROP TABLE t_test_vector_table; -- result: -- !result +ADMIN SET FRONTEND CONFIG("enable_experimental_vector" = "false"); +-- result: +-- !result --- name: test_vector_index +-- name: test_vector_index @sequential +ADMIN SET FRONTEND CONFIG("enable_experimental_vector" = "true"); +-- result: +-- !result CREATE TABLE `t_test_vector_table` ( `id` bigint(20) NOT NULL COMMENT "", `vector1` ARRAY NOT NULL COMMENT "", @@ -88,4 +101,8 @@ select * from (select id, approx_l2_distance([1,1,1,1,1], vector1) score from t_ DROP TABLE t_test_vector_table; -- result: --- !result \ No newline at end of file +-- !result + +ADMIN SET FRONTEND CONFIG("enable_experimental_vector" = "false"); +-- result: +-- !result diff --git a/test/sql/test_vector_index/R/test_vector_index_hnsw b/test/sql/test_vector_index/R/test_vector_index_hnsw new file mode 100644 index 0000000000000..8ad6d634257b1 --- /dev/null +++ b/test/sql/test_vector_index/R/test_vector_index_hnsw @@ -0,0 +1,484 @@ +-- name: test_vector_index_hnsw @sequential +ADMIN SET FRONTEND CONFIG("enable_experimental_vector" = "true"); +-- result: +-- !result +CREATE TABLE __row_util_base ( + k1 bigint NULL +) ENGINE=OLAP +DUPLICATE KEY(`k1`) +DISTRIBUTED BY HASH(`k1`) BUCKETS 32 +PROPERTIES ( + "replication_num" = "1" +); +-- result: +-- !result +insert into __row_util_base select generate_series from TABLE(generate_series(0, 10000 - 1)); +-- result: +-- !result +insert into __row_util_base select * from 
__row_util_base; -- 20000 +insert into __row_util_base select * from __row_util_base; -- 40000 +insert into __row_util_base select * from __row_util_base; -- 80000 +insert into __row_util_base select * from __row_util_base; -- 160000 +insert into __row_util_base select * from __row_util_base; -- 320000 +insert into __row_util_base select * from __row_util_base; -- 640000 + +CREATE TABLE __row_util ( + idx bigint NULL +) ENGINE=OLAP +DUPLICATE KEY(`idx`) +DISTRIBUTED BY HASH(`idx`) BUCKETS 32 +PROPERTIES ( + "replication_num" = "1" +); +-- result: +-- !result +insert into __row_util +select + row_number() over() as idx +from __row_util_base; +-- result: +-- !result +CREATE TABLE t2 ( + id bigint(20) NOT NULL, + v1 ARRAY NOT NULL, + v2 ARRAY NOT NULL, + i1 bigint(20) NOT NULL, + INDEX index_vector (v1) USING VECTOR ( + "index_type" = "hnsw", + "dim"="5", + "metric_type" = "l2_distance", + "is_vector_normed" = "false", + "M" = "160", + "efconstruction" = "400") +) ENGINE=OLAP +DUPLICATE KEY(id) +DISTRIBUTED BY HASH(id) BUCKETS 64 +PROPERTIES ( + "replication_num" = "1" +); +-- result: +-- !result +insert into t2 +select + idx, + array_generate(10000, 10004), + array_generate(10000, 10004), + idx +from __row_util +order by idx +limit 20; +-- result: +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [10000, 10001, 10002, 10003, 10004]) as dis from t2 +), w2 as ( + select * from w1 + order by dis limit 21 +) select * from w2 order by dis, id; +-- result: +1 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 1 0.0 +2 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 2 0.0 +3 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 3 0.0 +4 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 4 0.0 +5 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 5 0.0 +6 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 6 0.0 +7 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 7 0.0 +8 
[10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 8 0.0 +9 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 9 0.0 +10 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 10 0.0 +11 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 11 0.0 +12 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 12 0.0 +13 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 13 0.0 +14 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 14 0.0 +15 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 15 0.0 +16 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 16 0.0 +17 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 17 0.0 +18 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 18 0.0 +19 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 19 0.0 +20 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 20 0.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1, 1]) as dis from t2 +) +select * from w1 +order by dis; +-- result: +9 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 9 500100000.0 +16 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 16 500100000.0 +15 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 15 500100000.0 +13 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 13 500100000.0 +20 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 20 500100000.0 +12 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 12 500100000.0 +2 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 2 500100000.0 +7 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 7 500100000.0 +14 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 14 500100000.0 +10 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 10 500100000.0 +19 [10000,10001,10002,10003,10004] 
[10000,10001,10002,10003,10004] 19 500100000.0 +11 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 11 500100000.0 +18 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 18 500100000.0 +8 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 8 500100000.0 +17 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 17 500100000.0 +1 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 1 500100000.0 +3 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 3 500100000.0 +5 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 5 500100000.0 +4 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 4 500100000.0 +6 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 6 500100000.0 +-- !result +insert into t2 +select + idx + 20, + array_repeat(idx, 5), + array_repeat(idx, 5), + idx + 20 +from __row_util; +-- result: +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1, 1]) as dis from t2 +) +select * from w1 order by dis limit 10; +-- result: +21 [1,1,1,1,1] [1,1,1,1,1] 21 0.0 +22 [2,2,2,2,2] [2,2,2,2,2] 22 5.0 +23 [3,3,3,3,3] [3,3,3,3,3] 23 20.0 +24 [4,4,4,4,4] [4,4,4,4,4] 24 45.0 +25 [5,5,5,5,5] [5,5,5,5,5] 25 80.0 +26 [6,6,6,6,6] [6,6,6,6,6] 26 125.0 +27 [7,7,7,7,7] [7,7,7,7,7] 27 180.0 +28 [8,8,8,8,8] [8,8,8,8,8] 28 245.0 +29 [9,9,9,9,9] [9,9,9,9,9] 29 320.0 +30 [10,10,10,10,10] [10,10,10,10,10] 30 405.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [640064, 640064, 640064, 640064, 640064]) as dis from t2 +) +select * from w1 +order by dis limit 10; +-- result: +640020 [640000,640000,640000,640000,640000] [640000,640000,640000,640000,640000] 640020 20480.0 +640019 [639999,639999,639999,639999,639999] [639999,639999,639999,639999,639999] 640019 21125.0 +640018 [639998,639998,639998,639998,639998] [639998,639998,639998,639998,639998] 640018 21780.0 +640017 [639997,639997,639997,639997,639997] [639997,639997,639997,639997,639997] 640017 
22445.0 +640016 [639996,639996,639996,639996,639996] [639996,639996,639996,639996,639996] 640016 23120.0 +640015 [639995,639995,639995,639995,639995] [639995,639995,639995,639995,639995] 640015 23805.0 +640014 [639994,639994,639994,639994,639994] [639994,639994,639994,639994,639994] 640014 24500.0 +640013 [639993,639993,639993,639993,639993] [639993,639993,639993,639993,639993] 640013 25205.0 +640012 [639992,639992,639992,639992,639992] [639992,639992,639992,639992,639992] 640012 25920.0 +640011 [639991,639991,639991,639991,639991] [639991,639991,639991,639991,639991] 640011 26645.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [1, 2, 3, 4, 5]) as dis from t2 +) +select * from w1 +order by dis limit 10; +-- result: +23 [3,3,3,3,3] [3,3,3,3,3] 23 10.0 +22 [2,2,2,2,2] [2,2,2,2,2] 22 15.0 +24 [4,4,4,4,4] [4,4,4,4,4] 24 15.0 +21 [1,1,1,1,1] [1,1,1,1,1] 21 30.0 +25 [5,5,5,5,5] [5,5,5,5,5] 25 30.0 +26 [6,6,6,6,6] [6,6,6,6,6] 26 55.0 +27 [7,7,7,7,7] [7,7,7,7,7] 27 90.0 +28 [8,8,8,8,8] [8,8,8,8,8] 28 135.0 +29 [9,9,9,9,9] [9,9,9,9,9] 29 190.0 +30 [10,10,10,10,10] [10,10,10,10,10] 30 255.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [10000, 10001, 10002, 10003, 10004]) as dis from t2 +), w2 as ( + select * from w1 + order by dis limit 20 +) select * from w2 order by dis, id; +-- result: +1 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 1 0.0 +2 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 2 0.0 +3 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 3 0.0 +4 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 4 0.0 +5 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 5 0.0 +6 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 6 0.0 +7 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 7 0.0 +8 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 8 0.0 +9 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 9 0.0 +10 
[10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 10 0.0 +11 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 11 0.0 +12 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 12 0.0 +13 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 13 0.0 +14 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 14 0.0 +15 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 15 0.0 +16 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 16 0.0 +17 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 17 0.0 +18 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 18 0.0 +19 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 19 0.0 +20 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 20 0.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1, 1]) as dis from t2 +) +select * from w1 +where dis <= 0 +order by dis limit 10; +-- result: +21 [1,1,1,1,1] [1,1,1,1,1] 21 0.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1, 1]) as dis from t2 +) +select * from w1 +where dis <= 100 and dis <= 1000 +order by dis limit 10; +-- result: +21 [1,1,1,1,1] [1,1,1,1,1] 21 0.0 +22 [2,2,2,2,2] [2,2,2,2,2] 22 5.0 +23 [3,3,3,3,3] [3,3,3,3,3] 23 20.0 +24 [4,4,4,4,4] [4,4,4,4,4] 24 45.0 +25 [5,5,5,5,5] [5,5,5,5,5] 25 80.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [640064, 640064, 640064, 640064, 640064]) as dis from t2 +) +select * from w1 +where dis <= 100 and dis <= 1000 +order by dis limit 10; +-- result: +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [1, 2, 3, 4, 5]) as dis from t2 +) +select * from w1 +where dis <= 100 and dis <= 1000 +order by dis limit 10; +-- result: +23 [3,3,3,3,3] [3,3,3,3,3] 23 10.0 +22 [2,2,2,2,2] [2,2,2,2,2] 22 15.0 +24 [4,4,4,4,4] [4,4,4,4,4] 24 15.0 +21 [1,1,1,1,1] [1,1,1,1,1] 21 30.0 +25 [5,5,5,5,5] [5,5,5,5,5] 25 30.0 +26 [6,6,6,6,6] [6,6,6,6,6] 26 55.0 +27 
[7,7,7,7,7] [7,7,7,7,7] 27 90.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [10000, 10001, 10002, 10003, 10004]) as dis from t2 +), w2 as ( + select * from w1 + where dis <= 100 and dis <= 1000 + order by dis limit 20 +) select * from w2 order by dis, id; +-- result: +1 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 1 0.0 +2 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 2 0.0 +3 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 3 0.0 +4 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 4 0.0 +5 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 5 0.0 +6 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 6 0.0 +7 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 7 0.0 +8 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 8 0.0 +9 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 9 0.0 +10 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 10 0.0 +11 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 11 0.0 +12 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 12 0.0 +13 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 13 0.0 +14 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 14 0.0 +15 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 15 0.0 +16 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 16 0.0 +17 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 17 0.0 +18 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 18 0.0 +19 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 19 0.0 +20 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 20 0.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1, 1]) as dis from t2 +) +select * from w1 +where dis >= 100 and dis <= 1000 +order by dis limit 10; +-- result: +26 [6,6,6,6,6] [6,6,6,6,6] 26 125.0 +27 
[7,7,7,7,7] [7,7,7,7,7] 27 180.0 +28 [8,8,8,8,8] [8,8,8,8,8] 28 245.0 +29 [9,9,9,9,9] [9,9,9,9,9] 29 320.0 +30 [10,10,10,10,10] [10,10,10,10,10] 30 405.0 +31 [11,11,11,11,11] [11,11,11,11,11] 31 500.0 +32 [12,12,12,12,12] [12,12,12,12,12] 32 605.0 +33 [13,13,13,13,13] [13,13,13,13,13] 33 720.0 +34 [14,14,14,14,14] [14,14,14,14,14] 34 845.0 +35 [15,15,15,15,15] [15,15,15,15,15] 35 980.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1, 1]) as dis from t2 +) +select * from w1 +where dis <= 100 and id >= 0 +order by dis limit 10; +-- result: +21 [1,1,1,1,1] [1,1,1,1,1] 21 0.0 +22 [2,2,2,2,2] [2,2,2,2,2] 22 5.0 +23 [3,3,3,3,3] [3,3,3,3,3] 23 20.0 +24 [4,4,4,4,4] [4,4,4,4,4] 24 45.0 +25 [5,5,5,5,5] [5,5,5,5,5] 25 80.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1, 1]) as dis from t2 +) +select * from w1 +order by dis, id limit 10; +-- result: +21 [1,1,1,1,1] [1,1,1,1,1] 21 0.0 +22 [2,2,2,2,2] [2,2,2,2,2] 22 5.0 +23 [3,3,3,3,3] [3,3,3,3,3] 23 20.0 +24 [4,4,4,4,4] [4,4,4,4,4] 24 45.0 +25 [5,5,5,5,5] [5,5,5,5,5] 25 80.0 +26 [6,6,6,6,6] [6,6,6,6,6] 26 125.0 +27 [7,7,7,7,7] [7,7,7,7,7] 27 180.0 +28 [8,8,8,8,8] [8,8,8,8,8] 28 245.0 +29 [9,9,9,9,9] [9,9,9,9,9] 29 320.0 +30 [10,10,10,10,10] [10,10,10,10,10] 30 405.0 +-- !result +ADMIN SET FRONTEND CONFIG("enable_experimental_vector" = "false"); +-- result: +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1, 1]) as dis from t2 +) +select * from w1 order by dis limit 10; +-- result: +21 [1,1,1,1,1] [1,1,1,1,1] 21 0.0 +22 [2,2,2,2,2] [2,2,2,2,2] 22 5.0 +23 [3,3,3,3,3] [3,3,3,3,3] 23 20.0 +24 [4,4,4,4,4] [4,4,4,4,4] 24 45.0 +25 [5,5,5,5,5] [5,5,5,5,5] 25 80.0 +26 [6,6,6,6,6] [6,6,6,6,6] 26 125.0 +27 [7,7,7,7,7] [7,7,7,7,7] 27 180.0 +28 [8,8,8,8,8] [8,8,8,8,8] 28 245.0 +29 [9,9,9,9,9] [9,9,9,9,9] 29 320.0 +30 [10,10,10,10,10] [10,10,10,10,10] 30 405.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [640064, 640064, 640064, 640064, 640064]) as dis 
from t2 +) +select * from w1 +order by dis limit 10; +-- result: +640020 [640000,640000,640000,640000,640000] [640000,640000,640000,640000,640000] 640020 20480.0 +640019 [639999,639999,639999,639999,639999] [639999,639999,639999,639999,639999] 640019 21125.0 +640018 [639998,639998,639998,639998,639998] [639998,639998,639998,639998,639998] 640018 21780.0 +640017 [639997,639997,639997,639997,639997] [639997,639997,639997,639997,639997] 640017 22445.0 +640016 [639996,639996,639996,639996,639996] [639996,639996,639996,639996,639996] 640016 23120.0 +640015 [639995,639995,639995,639995,639995] [639995,639995,639995,639995,639995] 640015 23805.0 +640014 [639994,639994,639994,639994,639994] [639994,639994,639994,639994,639994] 640014 24500.0 +640013 [639993,639993,639993,639993,639993] [639993,639993,639993,639993,639993] 640013 25205.0 +640012 [639992,639992,639992,639992,639992] [639992,639992,639992,639992,639992] 640012 25920.0 +640011 [639991,639991,639991,639991,639991] [639991,639991,639991,639991,639991] 640011 26645.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [1, 2, 3, 4, 5]) as dis from t2 +) +select * from w1 +order by dis limit 10; +-- result: +23 [3,3,3,3,3] [3,3,3,3,3] 23 10.0 +22 [2,2,2,2,2] [2,2,2,2,2] 22 15.0 +24 [4,4,4,4,4] [4,4,4,4,4] 24 15.0 +21 [1,1,1,1,1] [1,1,1,1,1] 21 30.0 +25 [5,5,5,5,5] [5,5,5,5,5] 25 30.0 +26 [6,6,6,6,6] [6,6,6,6,6] 26 55.0 +27 [7,7,7,7,7] [7,7,7,7,7] 27 90.0 +28 [8,8,8,8,8] [8,8,8,8,8] 28 135.0 +29 [9,9,9,9,9] [9,9,9,9,9] 29 190.0 +30 [10,10,10,10,10] [10,10,10,10,10] 30 255.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [10000, 10001, 10002, 10003, 10004]) as dis from t2 +), w2 as ( + select * from w1 + order by dis limit 20 +) select * from w2 order by dis, id; +-- result: +1 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 1 0.0 +2 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 2 0.0 +3 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 3 0.0 +4 
[10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 4 0.0 +5 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 5 0.0 +6 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 6 0.0 +7 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 7 0.0 +8 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 8 0.0 +9 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 9 0.0 +10 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 10 0.0 +11 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 11 0.0 +12 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 12 0.0 +13 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 13 0.0 +14 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 14 0.0 +15 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 15 0.0 +16 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 16 0.0 +17 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 17 0.0 +18 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 18 0.0 +19 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 19 0.0 +20 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 20 0.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1, 1]) as dis from t2 +) +select * from w1 +where dis <= 100 and dis <= 1000 +order by dis limit 10; +-- result: +21 [1,1,1,1,1] [1,1,1,1,1] 21 0.0 +22 [2,2,2,2,2] [2,2,2,2,2] 22 5.0 +23 [3,3,3,3,3] [3,3,3,3,3] 23 20.0 +24 [4,4,4,4,4] [4,4,4,4,4] 24 45.0 +25 [5,5,5,5,5] [5,5,5,5,5] 25 80.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [640064, 640064, 640064, 640064, 640064]) as dis from t2 +) +select * from w1 +where dis <= 100 and dis <= 1000 +order by dis limit 10; +-- result: +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [1, 2, 3, 4, 5]) as dis from t2 +) +select * from w1 +where dis <= 100 and dis <= 1000 +order by dis limit 
10; +-- result: +23 [3,3,3,3,3] [3,3,3,3,3] 23 10.0 +22 [2,2,2,2,2] [2,2,2,2,2] 22 15.0 +24 [4,4,4,4,4] [4,4,4,4,4] 24 15.0 +21 [1,1,1,1,1] [1,1,1,1,1] 21 30.0 +25 [5,5,5,5,5] [5,5,5,5,5] 25 30.0 +26 [6,6,6,6,6] [6,6,6,6,6] 26 55.0 +27 [7,7,7,7,7] [7,7,7,7,7] 27 90.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [10000, 10001, 10002, 10003, 10004]) as dis from t2 +), w2 as ( + select * from w1 + where dis <= 100 and dis <= 1000 + order by dis limit 20 +) select * from w2 order by dis, id; +-- result: +1 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 1 0.0 +2 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 2 0.0 +3 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 3 0.0 +4 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 4 0.0 +5 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 5 0.0 +6 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 6 0.0 +7 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 7 0.0 +8 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 8 0.0 +9 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 9 0.0 +10 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 10 0.0 +11 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 11 0.0 +12 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 12 0.0 +13 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 13 0.0 +14 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 14 0.0 +15 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 15 0.0 +16 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 16 0.0 +17 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 17 0.0 +18 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 18 0.0 +19 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 19 0.0 +20 [10000,10001,10002,10003,10004] [10000,10001,10002,10003,10004] 20 
0.0 +-- !result +ADMIN SET FRONTEND CONFIG("enable_experimental_vector" = "false"); +-- result: +-- !result \ No newline at end of file diff --git a/test/sql/test_vector_index/R/test_vector_index_insert b/test/sql/test_vector_index/R/test_vector_index_insert new file mode 100644 index 0000000000000..31de99e0e4c6b --- /dev/null +++ b/test/sql/test_vector_index/R/test_vector_index_insert @@ -0,0 +1,170 @@ +-- name: test_vector_index_insert @sequential +ADMIN SET FRONTEND CONFIG("enable_experimental_vector" = "true"); +-- result: +-- !result +CREATE TABLE t1 ( + id bigint(20) NOT NULL, + v1 ARRAY NOT NULL, + v2 ARRAY NULL, + INDEX index_vector (v1) USING VECTOR ( + "index_type" = "hnsw", + "dim"="5", + "metric_type" = "cosine_similarity", + "is_vector_normed" = "true", + "M" = "16", + "efconstruction" = "40") +) ENGINE=OLAP +DUPLICATE KEY(id) +DISTRIBUTED BY HASH(id) BUCKETS 64 +PROPERTIES ( + "replication_num" = "1" +); +-- result: +-- !result +INSERT into t1 values + (1, null, null); +-- result: +[REGEX].*Insert has filtered data.* +-- !result +INSERT into t1 values + (1, [null, null, null, null, null], [1,2,3,4,5]); +-- result: +[REGEX].*The input vector is not normalized but `metric_type` is cosine_similarity and `is_vector_normed` is true.* +-- !result +INSERT into t1 values + (1, [1,2,3,4], [1,2,3,4]); +-- result: +[REGEX].*The dimensions of the vector written are inconsistent, index dim is 5 but data dim is 4.* +-- !result +INSERT into t1 values + (1, [], []); +-- result: +[REGEX].*The dimensions of the vector written are inconsistent, index dim is 5 but data dim is 0.* +-- !result +INSERT INTO t1 values + (1, [1,2,3,4,5], [1,2,3,4,5]), + (2, [4,5,6,7,8], [4,5,6,7,8]); +-- result: +[REGEX].*The input vector is not normalized but `metric_type` is cosine_similarity and `is_vector_normed` is true.* +-- !result +INSERT INTO t1 values + (1, [0.13483997249264842, 0.26967994498529685, 0.40451991747794525, 0.5393598899705937, 0.674199862463242], + [0.13483997249264842, 
0.26967994498529685, 0.40451991747794525, 0.5393598899705937, 0.674199862463242]), + (2, [0.29019050004400465, 0.36273812505500586, 0.435285750066007, 0.5078333750770082, 0.5803810000880093], + [0.29019050004400465, 0.36273812505500586, 0.435285750066007, 0.5078333750770082, 0.5803810000880093]), + (3, [0.3368607684266076, 0.42107596053325946, 0.5052911526399114, null, 0.6737215368532152], + [0.3368607684266076, 0.42107596053325946, 0.5052911526399114, null, 0.6737215368532152]), + (4, [0.3368607684266076, 0.42107596053325946, 0.5052911526399114, null, 0.6737215368532152], + null); +-- result: +-- !result +INSERT INTO t1 values + (1, [1,2,3,4,5], [1,2,3,4,5]), + (2, [4,5,6,7,8], [4,5,6,7,8]), + (3, null, null); +-- result: +[REGEX].*The input vector is not normalized but `metric_type` is cosine_similarity and `is_vector_normed` is true.* +-- !result +INSERT INTO t1 values + (1, [0.13483997249264842, 0.26967994498529685, 0.40451991747794525, 0.5393598899705937, 0.674199862463242], + [0.13483997249264842, 0.26967994498529685, 0.40451991747794525, 0.5393598899705937, 0.674199862463242]), + (4, null, null), + (2, [0.29019050004400465, 0.36273812505500586, 0.435285750066007, 0.5078333750770082, 0.5803810000880093], + [0.29019050004400465, 0.36273812505500586, 0.435285750066007, 0.5078333750770082, 0.5803810000880093]), + (3, [0.3368607684266076, 0.42107596053325946, 0.5052911526399114, null, 0.6737215368532152], + [0.3368607684266076, 0.42107596053325946, 0.5052911526399114, null, 0.6737215368532152]), + (3, null, null), + (5, null, null), + (6, null, null), + (7, null, null), + (8, null, null), + (9, null, null), + (10, null, null); +-- result: +[REGEX].*Insert has filtered data.* +-- !result +select * from t1 order by id; +-- result: +1 [0.13483997,0.26967993,0.40451992,0.53935987,0.6741999] [0.13483997,0.26967993,0.40451992,0.53935987,0.6741999] +2 [0.2901905,0.36273813,0.43528575,0.50783336,0.580381] [0.2901905,0.36273813,0.43528575,0.50783336,0.580381] +3 
[0.33686078,0.42107597,0.50529116,null,0.67372155] [0.33686078,0.42107597,0.50529116,null,0.67372155] +4 [0.33686078,0.42107597,0.50529116,null,0.67372155] None +-- !result +CREATE TABLE t2 ( + id bigint(20) NOT NULL, + v1 ARRAY NOT NULL, + v2 ARRAY NULL, + INDEX index_vector (v1) USING VECTOR ( + "index_type" = "hnsw", + "dim"="5", + "metric_type" = "cosine_similarity", + "is_vector_normed" = "false", + "M" = "16", + "efconstruction" = "40") +) ENGINE=OLAP +DUPLICATE KEY(id) +DISTRIBUTED BY HASH(id) BUCKETS 64 +PROPERTIES ( + "replication_num" = "1" +); +-- result: +-- !result +INSERT INTO t2 values + (1, [1,2,3,4,5], [1,2,3,4,5]), + (2, [4,5,6,7,8], [4,5,6,7,8]), + (3, [4,5,6,null,8], [4,5,6,null,8]), + (4, [null, null, null, null], [null, null, null, null]), + (5, [4,5,6,7,8], null); +-- result: +[REGEX].*The dimensions of the vector written are inconsistent, index dim is 5 but data dim is 4.* +-- !result +INSERT INTO t2 values + (1, [1,2,3,4,5], [1,2,3,4,5]), + (2, [4,5,6,7], [4,5,6,7,8]), + (3, [4,5,6,null,8], [4,5,6,null,8]), + (4, [null, null, null, null], [null, null, null, null]), + (5, [4,5,6,7,8], null); +-- result: +[REGEX].*The dimensions of the vector written are inconsistent, index dim is 5 but data dim is 4.* +-- !result +select * from t2 order by id; +-- result: +-- !result +insert into t1 select * from t2; +-- result: +-- !result +insert into t1 select * from t1; +-- result: +-- !result +select * from t1 order by id, v1, v2; +-- result: +1 [0.13483997,0.26967993,0.40451992,0.53935987,0.6741999] [0.13483997,0.26967993,0.40451992,0.53935987,0.6741999] +1 [0.13483997,0.26967993,0.40451992,0.53935987,0.6741999] [0.13483997,0.26967993,0.40451992,0.53935987,0.6741999] +2 [0.2901905,0.36273813,0.43528575,0.50783336,0.580381] [0.2901905,0.36273813,0.43528575,0.50783336,0.580381] +2 [0.2901905,0.36273813,0.43528575,0.50783336,0.580381] [0.2901905,0.36273813,0.43528575,0.50783336,0.580381] +3 [0.33686078,0.42107597,0.50529116,null,0.67372155] 
[0.33686078,0.42107597,0.50529116,null,0.67372155] +3 [0.33686078,0.42107597,0.50529116,null,0.67372155] [0.33686078,0.42107597,0.50529116,null,0.67372155] +4 [0.33686078,0.42107597,0.50529116,null,0.67372155] None +4 [0.33686078,0.42107597,0.50529116,null,0.67372155] None +-- !result +insert into t2 select * from t1; +-- result: +-- !result +insert into t2 select id, v2, v1 from t2; +-- result: +[REGEX].*Insert has filtered data.* +-- !result +select * from t2 order by id, v1, v2; +-- result: +1 [0.13483997,0.26967993,0.40451992,0.53935987,0.6741999] [0.13483997,0.26967993,0.40451992,0.53935987,0.6741999] +1 [0.13483997,0.26967993,0.40451992,0.53935987,0.6741999] [0.13483997,0.26967993,0.40451992,0.53935987,0.6741999] +2 [0.2901905,0.36273813,0.43528575,0.50783336,0.580381] [0.2901905,0.36273813,0.43528575,0.50783336,0.580381] +2 [0.2901905,0.36273813,0.43528575,0.50783336,0.580381] [0.2901905,0.36273813,0.43528575,0.50783336,0.580381] +3 [0.33686078,0.42107597,0.50529116,null,0.67372155] [0.33686078,0.42107597,0.50529116,null,0.67372155] +3 [0.33686078,0.42107597,0.50529116,null,0.67372155] [0.33686078,0.42107597,0.50529116,null,0.67372155] +4 [0.33686078,0.42107597,0.50529116,null,0.67372155] None +4 [0.33686078,0.42107597,0.50529116,null,0.67372155] None +-- !result +ADMIN SET FRONTEND CONFIG("enable_experimental_vector" = "false"); +-- result: +-- !result diff --git a/test/sql/test_vector_index/R/test_vector_index_ivfpq b/test/sql/test_vector_index/R/test_vector_index_ivfpq new file mode 100644 index 0000000000000..999d35a0936c9 --- /dev/null +++ b/test/sql/test_vector_index/R/test_vector_index_ivfpq @@ -0,0 +1,536 @@ +-- name: test_vector_index_ivfpq @sequential +ADMIN SET FRONTEND CONFIG("enable_experimental_vector" = "true"); +-- result: +-- !result +CREATE TABLE __row_util_base ( + k1 bigint NULL +) ENGINE=OLAP +DUPLICATE KEY(`k1`) +DISTRIBUTED BY HASH(`k1`) BUCKETS 32 +PROPERTIES ( + "replication_num" = "1" +); +-- result: +-- !result +insert into 
__row_util_base select generate_series from TABLE(generate_series(0, 10000 - 1)); +-- result: +-- !result +insert into __row_util_base select * from __row_util_base; -- 20000 +insert into __row_util_base select * from __row_util_base; -- 40000 +insert into __row_util_base select * from __row_util_base; -- 80000 +insert into __row_util_base select * from __row_util_base; -- 160000 +insert into __row_util_base select * from __row_util_base; -- 320000 +insert into __row_util_base select * from __row_util_base; -- 640000 + +CREATE TABLE __row_util ( + idx bigint NULL +) ENGINE=OLAP +DUPLICATE KEY(`idx`) +DISTRIBUTED BY HASH(`idx`) BUCKETS 32 +PROPERTIES ( + "replication_num" = "1" +); +-- result: +-- !result +insert into __row_util +select + row_number() over() as idx +from __row_util_base; +-- result: +-- !result +CREATE TABLE t1 ( + id bigint(20) NOT NULL, + v1 ARRAY NOT NULL, + v2 ARRAY NOT NULL, + i1 bigint(20) NOT NULL, + INDEX index_vector (v1) USING VECTOR ( + "index_type" = "IVFPQ", + "dim"="4", + "metric_type" = "l2_distance", + "is_vector_normed" = "false", + "nbits" = "8", + "nlist" = "40", + "M_IVFPQ" = "2") +) ENGINE=OLAP +DUPLICATE KEY(id) +DISTRIBUTED BY HASH(id) BUCKETS 64 +PROPERTIES ( + "replication_num" = "1" +); +-- result: +-- !result +CREATE TABLE t1 ( + id bigint(20) NOT NULL, + v1 ARRAY NOT NULL, + v2 ARRAY NOT NULL, + i1 bigint(20) NOT NULL, + INDEX index_vector (v1) USING VECTOR ( + "index_type" = "ivfpq", + "dim"="4", + "metric_type" = "l2_distance", + "is_vector_normed" = "false", + "nbits" = "8", + "nlist" = "16", + "M_IVFPQ" = "2") +) ENGINE=OLAP +DUPLICATE KEY(id) +DISTRIBUTED BY HASH(id) BUCKETS 64 +PROPERTIES ( + "replication_num" = "1" +); +-- result: +E: (1050, "Getting analyzing error. 
Detail message: Table 't1' already exists.") +-- !result +insert into t1 +select + idx, + array_generate(10000, 10003), + array_generate(10000, 10003), + idx +from __row_util +order by idx +limit 20; +-- result: +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [10000, 10001, 10002, 10003]) as dis from t1 +), w2 as ( + select * from w1 + order by dis limit 21 +) select * from w2 order by dis, id; +-- result: +1 [10000,10001,10002,10003] [10000,10001,10002,10003] 1 0.0 +2 [10000,10001,10002,10003] [10000,10001,10002,10003] 2 0.0 +3 [10000,10001,10002,10003] [10000,10001,10002,10003] 3 0.0 +4 [10000,10001,10002,10003] [10000,10001,10002,10003] 4 0.0 +5 [10000,10001,10002,10003] [10000,10001,10002,10003] 5 0.0 +6 [10000,10001,10002,10003] [10000,10001,10002,10003] 6 0.0 +7 [10000,10001,10002,10003] [10000,10001,10002,10003] 7 0.0 +8 [10000,10001,10002,10003] [10000,10001,10002,10003] 8 0.0 +9 [10000,10001,10002,10003] [10000,10001,10002,10003] 9 0.0 +10 [10000,10001,10002,10003] [10000,10001,10002,10003] 10 0.0 +11 [10000,10001,10002,10003] [10000,10001,10002,10003] 11 0.0 +12 [10000,10001,10002,10003] [10000,10001,10002,10003] 12 0.0 +13 [10000,10001,10002,10003] [10000,10001,10002,10003] 13 0.0 +14 [10000,10001,10002,10003] [10000,10001,10002,10003] 14 0.0 +15 [10000,10001,10002,10003] [10000,10001,10002,10003] 15 0.0 +16 [10000,10001,10002,10003] [10000,10001,10002,10003] 16 0.0 +17 [10000,10001,10002,10003] [10000,10001,10002,10003] 17 0.0 +18 [10000,10001,10002,10003] [10000,10001,10002,10003] 18 0.0 +19 [10000,10001,10002,10003] [10000,10001,10002,10003] 19 0.0 +20 [10000,10001,10002,10003] [10000,10001,10002,10003] 20 0.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1]) as dis from t1 +) +select * from w1 +order by dis; +-- result: +8 [10000,10001,10002,10003] [10000,10001,10002,10003] 8 400040000.0 +17 [10000,10001,10002,10003] [10000,10001,10002,10003] 17 400040000.0 +1 [10000,10001,10002,10003] [10000,10001,10002,10003] 1 
400040000.0 +3 [10000,10001,10002,10003] [10000,10001,10002,10003] 3 400040000.0 +5 [10000,10001,10002,10003] [10000,10001,10002,10003] 5 400040000.0 +4 [10000,10001,10002,10003] [10000,10001,10002,10003] 4 400040000.0 +6 [10000,10001,10002,10003] [10000,10001,10002,10003] 6 400040000.0 +9 [10000,10001,10002,10003] [10000,10001,10002,10003] 9 400040000.0 +16 [10000,10001,10002,10003] [10000,10001,10002,10003] 16 400040000.0 +15 [10000,10001,10002,10003] [10000,10001,10002,10003] 15 400040000.0 +13 [10000,10001,10002,10003] [10000,10001,10002,10003] 13 400040000.0 +20 [10000,10001,10002,10003] [10000,10001,10002,10003] 20 400040000.0 +12 [10000,10001,10002,10003] [10000,10001,10002,10003] 12 400040000.0 +2 [10000,10001,10002,10003] [10000,10001,10002,10003] 2 400040000.0 +7 [10000,10001,10002,10003] [10000,10001,10002,10003] 7 400040000.0 +14 [10000,10001,10002,10003] [10000,10001,10002,10003] 14 400040000.0 +10 [10000,10001,10002,10003] [10000,10001,10002,10003] 10 400040000.0 +19 [10000,10001,10002,10003] [10000,10001,10002,10003] 19 400040000.0 +11 [10000,10001,10002,10003] [10000,10001,10002,10003] 11 400040000.0 +18 [10000,10001,10002,10003] [10000,10001,10002,10003] 18 400040000.0 +-- !result +insert into t1 +select + idx + 20, + array_repeat(idx, 4), + array_repeat(idx, 4), + idx + 20 +from __row_util; +-- result: +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1]) as dis from t1 +) +select id, v1, dis from w1 order by dis limit 10; +-- result: +21 [1,1,1,1] 0.0 +22 [2,2,2,2] 4.0 +23 [3,3,3,3] 16.0 +24 [4,4,4,4] 36.0 +25 [5,5,5,5] 64.0 +26 [6,6,6,6] 100.0 +27 [7,7,7,7] 144.0 +28 [8,8,8,8] 196.0 +29 [9,9,9,9] 256.0 +30 [10,10,10,10] 324.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [640064, 640064, 640064, 640064]) as dis from t1 +) +select * from w1 +order by dis limit 10; +-- result: +640020 [640000,640000,640000,640000] [640000,640000,640000,640000] 640020 16384.0 +640019 [639999,639999,639999,639999] 
[639999,639999,639999,639999] 640019 16900.0 +640018 [639998,639998,639998,639998] [639998,639998,639998,639998] 640018 17424.0 +640017 [639997,639997,639997,639997] [639997,639997,639997,639997] 640017 17956.0 +640016 [639996,639996,639996,639996] [639996,639996,639996,639996] 640016 18496.0 +640015 [639995,639995,639995,639995] [639995,639995,639995,639995] 640015 19044.0 +640014 [639994,639994,639994,639994] [639994,639994,639994,639994] 640014 19600.0 +640013 [639993,639993,639993,639993] [639993,639993,639993,639993] 640013 20164.0 +640012 [639992,639992,639992,639992] [639992,639992,639992,639992] 640012 20736.0 +640011 [639991,639991,639991,639991] [639991,639991,639991,639991] 640011 21316.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [1, 2, 3, 4]) as dis from t1 +) +select * from w1 +order by dis limit 10; +-- result: +22 [2,2,2,2] [2,2,2,2] 22 6.0 +23 [3,3,3,3] [3,3,3,3] 23 6.0 +24 [4,4,4,4] [4,4,4,4] 24 14.0 +21 [1,1,1,1] [1,1,1,1] 21 14.0 +25 [5,5,5,5] [5,5,5,5] 25 30.0 +26 [6,6,6,6] [6,6,6,6] 26 54.0 +27 [7,7,7,7] [7,7,7,7] 27 86.0 +28 [8,8,8,8] [8,8,8,8] 28 126.0 +29 [9,9,9,9] [9,9,9,9] 29 174.0 +30 [10,10,10,10] [10,10,10,10] 30 230.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [10000, 10001, 10002, 10003]) as dis from t1 +), w2 as ( + select * from w1 + order by dis limit 20 +) select * from w2 order by dis, id; +-- result: +1 [10000,10001,10002,10003] [10000,10001,10002,10003] 1 0.0 +2 [10000,10001,10002,10003] [10000,10001,10002,10003] 2 0.0 +3 [10000,10001,10002,10003] [10000,10001,10002,10003] 3 0.0 +4 [10000,10001,10002,10003] [10000,10001,10002,10003] 4 0.0 +5 [10000,10001,10002,10003] [10000,10001,10002,10003] 5 0.0 +6 [10000,10001,10002,10003] [10000,10001,10002,10003] 6 0.0 +7 [10000,10001,10002,10003] [10000,10001,10002,10003] 7 0.0 +8 [10000,10001,10002,10003] [10000,10001,10002,10003] 8 0.0 +9 [10000,10001,10002,10003] [10000,10001,10002,10003] 9 0.0 +10 [10000,10001,10002,10003] 
[10000,10001,10002,10003] 10 0.0 +11 [10000,10001,10002,10003] [10000,10001,10002,10003] 11 0.0 +12 [10000,10001,10002,10003] [10000,10001,10002,10003] 12 0.0 +13 [10000,10001,10002,10003] [10000,10001,10002,10003] 13 0.0 +14 [10000,10001,10002,10003] [10000,10001,10002,10003] 14 0.0 +15 [10000,10001,10002,10003] [10000,10001,10002,10003] 15 0.0 +16 [10000,10001,10002,10003] [10000,10001,10002,10003] 16 0.0 +17 [10000,10001,10002,10003] [10000,10001,10002,10003] 17 0.0 +18 [10000,10001,10002,10003] [10000,10001,10002,10003] 18 0.0 +19 [10000,10001,10002,10003] [10000,10001,10002,10003] 19 0.0 +20 [10000,10001,10002,10003] [10000,10001,10002,10003] 20 0.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1]) as dis from t1 +) +select v1, dis from w1 +where dis <= 0 +order by dis limit 10; +-- result: +[10000,10001,10002,10003] 400040000.0 +[10000,10001,10002,10003] 400040000.0 +[10000,10001,10002,10003] 400040000.0 +[10000,10001,10002,10003] 400040000.0 +[10000,10001,10002,10003] 400040000.0 +[10000,10001,10002,10003] 400040000.0 +[10000,10001,10002,10003] 400040000.0 +[10000,10001,10002,10003] 400040000.0 +[10000,10001,10002,10003] 400040000.0 +[10000,10001,10002,10003] 400040000.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1]) as dis from t1 +) +select v1, dis from w1 +where dis <= 100 +order by dis limit 10; +-- result: +[10000,10001,10002,10003] 400040000.0 +[10000,10001,10002,10003] 400040000.0 +[10000,10001,10002,10003] 400040000.0 +[10000,10001,10002,10003] 400040000.0 +[10000,10001,10002,10003] 400040000.0 +[10000,10001,10002,10003] 400040000.0 +[10000,10001,10002,10003] 400040000.0 +[10000,10001,10002,10003] 400040000.0 +[10000,10001,10002,10003] 400040000.0 +[10000,10001,10002,10003] 400040000.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [640064, 640064, 640064, 640064]) as dis from t1 +) +select v1, dis from w1 +where dis <= 100 and dis <= 1000 +order by dis limit 10; +-- result: 
+[10000,10001,10002,10003] 1587915000000.0 +[10000,10001,10002,10003] 1587915000000.0 +[10000,10001,10002,10003] 1587915000000.0 +[10000,10001,10002,10003] 1587915000000.0 +[10000,10001,10002,10003] 1587915000000.0 +[10000,10001,10002,10003] 1587915000000.0 +[10000,10001,10002,10003] 1587915000000.0 +[10000,10001,10002,10003] 1587915000000.0 +[10000,10001,10002,10003] 1587915000000.0 +[10000,10001,10002,10003] 1587915000000.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [1, 2, 3, 4]) as dis from t1 +) +select v1, dis from w1 +where dis <= 100 and dis <= 1000 +order by dis limit 10; +-- result: +[10000,10001,10002,10003] 399920000.0 +[10000,10001,10002,10003] 399920000.0 +[10000,10001,10002,10003] 399920000.0 +[10000,10001,10002,10003] 399920000.0 +[10000,10001,10002,10003] 399920000.0 +[10000,10001,10002,10003] 399920000.0 +[10000,10001,10002,10003] 399920000.0 +[10000,10001,10002,10003] 399920000.0 +[10000,10001,10002,10003] 399920000.0 +[10000,10001,10002,10003] 399920000.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [10000, 10001, 10002, 10003]) as dis from t1 +), w2 as ( + select * from w1 + where dis <= 100 and dis <= 1000 + order by dis limit 20 +) select v1, dis from w2 order by dis, id; +-- result: +[10000,10001,10002,10003] 0.0 +[10000,10001,10002,10003] 0.0 +[10000,10001,10002,10003] 0.0 +[10000,10001,10002,10003] 0.0 +[10000,10001,10002,10003] 0.0 +[10000,10001,10002,10003] 0.0 +[10000,10001,10002,10003] 0.0 +[10000,10001,10002,10003] 0.0 +[10000,10001,10002,10003] 0.0 +[10000,10001,10002,10003] 0.0 +[10000,10001,10002,10003] 0.0 +[10000,10001,10002,10003] 0.0 +[10000,10001,10002,10003] 0.0 +[10000,10001,10002,10003] 0.0 +[10000,10001,10002,10003] 0.0 +[10000,10001,10002,10003] 0.0 +[10000,10001,10002,10003] 0.0 +[10000,10001,10002,10003] 0.0 +[10000,10001,10002,10003] 0.0 +[10000,10001,10002,10003] 0.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1]) as dis from t1 +) +select * from w1 +where 
dis >= 100 and dis <= 1000 +order by dis limit 10; +-- result: +26 [6,6,6,6] [6,6,6,6] 26 100.0 +27 [7,7,7,7] [7,7,7,7] 27 144.0 +28 [8,8,8,8] [8,8,8,8] 28 196.0 +29 [9,9,9,9] [9,9,9,9] 29 256.0 +30 [10,10,10,10] [10,10,10,10] 30 324.0 +31 [11,11,11,11] [11,11,11,11] 31 400.0 +32 [12,12,12,12] [12,12,12,12] 32 484.0 +33 [13,13,13,13] [13,13,13,13] 33 576.0 +34 [14,14,14,14] [14,14,14,14] 34 676.0 +35 [15,15,15,15] [15,15,15,15] 35 784.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1]) as dis from t1 +) +select * from w1 +where dis <= 100 and id >= 0 +order by dis limit 10; +-- result: +21 [1,1,1,1] [1,1,1,1] 21 0.0 +22 [2,2,2,2] [2,2,2,2] 22 4.0 +23 [3,3,3,3] [3,3,3,3] 23 16.0 +24 [4,4,4,4] [4,4,4,4] 24 36.0 +25 [5,5,5,5] [5,5,5,5] 25 64.0 +26 [6,6,6,6] [6,6,6,6] 26 100.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1]) as dis from t1 +) +select * from w1 +order by dis, id limit 10; +-- result: +21 [1,1,1,1] [1,1,1,1] 21 0.0 +22 [2,2,2,2] [2,2,2,2] 22 4.0 +23 [3,3,3,3] [3,3,3,3] 23 16.0 +24 [4,4,4,4] [4,4,4,4] 24 36.0 +25 [5,5,5,5] [5,5,5,5] 25 64.0 +26 [6,6,6,6] [6,6,6,6] 26 100.0 +27 [7,7,7,7] [7,7,7,7] 27 144.0 +28 [8,8,8,8] [8,8,8,8] 28 196.0 +29 [9,9,9,9] [9,9,9,9] 29 256.0 +30 [10,10,10,10] [10,10,10,10] 30 324.0 +-- !result +ADMIN SET FRONTEND CONFIG("enable_experimental_vector" = "false"); +-- result: +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1]) as dis from t1 +) +select * from w1 order by dis limit 10; +-- result: +21 [1,1,1,1] [1,1,1,1] 21 0.0 +22 [2,2,2,2] [2,2,2,2] 22 4.0 +23 [3,3,3,3] [3,3,3,3] 23 16.0 +24 [4,4,4,4] [4,4,4,4] 24 36.0 +25 [5,5,5,5] [5,5,5,5] 25 64.0 +26 [6,6,6,6] [6,6,6,6] 26 100.0 +27 [7,7,7,7] [7,7,7,7] 27 144.0 +28 [8,8,8,8] [8,8,8,8] 28 196.0 +29 [9,9,9,9] [9,9,9,9] 29 256.0 +30 [10,10,10,10] [10,10,10,10] 30 324.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [640064, 640064, 640064, 640064]) as dis from t1 +) +select * from w1 +order 
by dis limit 10; +-- result: +640020 [640000,640000,640000,640000] [640000,640000,640000,640000] 640020 16384.0 +640019 [639999,639999,639999,639999] [639999,639999,639999,639999] 640019 16900.0 +640018 [639998,639998,639998,639998] [639998,639998,639998,639998] 640018 17424.0 +640017 [639997,639997,639997,639997] [639997,639997,639997,639997] 640017 17956.0 +640016 [639996,639996,639996,639996] [639996,639996,639996,639996] 640016 18496.0 +640015 [639995,639995,639995,639995] [639995,639995,639995,639995] 640015 19044.0 +640014 [639994,639994,639994,639994] [639994,639994,639994,639994] 640014 19600.0 +640013 [639993,639993,639993,639993] [639993,639993,639993,639993] 640013 20164.0 +640012 [639992,639992,639992,639992] [639992,639992,639992,639992] 640012 20736.0 +640011 [639991,639991,639991,639991] [639991,639991,639991,639991] 640011 21316.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [1, 2, 3, 4]) as dis from t1 +) +select * from w1 +order by dis limit 10; +-- result: +22 [2,2,2,2] [2,2,2,2] 22 6.0 +23 [3,3,3,3] [3,3,3,3] 23 6.0 +24 [4,4,4,4] [4,4,4,4] 24 14.0 +21 [1,1,1,1] [1,1,1,1] 21 14.0 +25 [5,5,5,5] [5,5,5,5] 25 30.0 +26 [6,6,6,6] [6,6,6,6] 26 54.0 +27 [7,7,7,7] [7,7,7,7] 27 86.0 +28 [8,8,8,8] [8,8,8,8] 28 126.0 +29 [9,9,9,9] [9,9,9,9] 29 174.0 +30 [10,10,10,10] [10,10,10,10] 30 230.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [10000, 10001, 10002, 10003]) as dis from t1 +), w2 as ( + select * from w1 + order by dis limit 20 +) select * from w2 order by dis, id; +-- result: +1 [10000,10001,10002,10003] [10000,10001,10002,10003] 1 0.0 +2 [10000,10001,10002,10003] [10000,10001,10002,10003] 2 0.0 +3 [10000,10001,10002,10003] [10000,10001,10002,10003] 3 0.0 +4 [10000,10001,10002,10003] [10000,10001,10002,10003] 4 0.0 +5 [10000,10001,10002,10003] [10000,10001,10002,10003] 5 0.0 +6 [10000,10001,10002,10003] [10000,10001,10002,10003] 6 0.0 +7 [10000,10001,10002,10003] [10000,10001,10002,10003] 7 0.0 +8 
[10000,10001,10002,10003] [10000,10001,10002,10003] 8 0.0 +9 [10000,10001,10002,10003] [10000,10001,10002,10003] 9 0.0 +10 [10000,10001,10002,10003] [10000,10001,10002,10003] 10 0.0 +11 [10000,10001,10002,10003] [10000,10001,10002,10003] 11 0.0 +12 [10000,10001,10002,10003] [10000,10001,10002,10003] 12 0.0 +13 [10000,10001,10002,10003] [10000,10001,10002,10003] 13 0.0 +14 [10000,10001,10002,10003] [10000,10001,10002,10003] 14 0.0 +15 [10000,10001,10002,10003] [10000,10001,10002,10003] 15 0.0 +16 [10000,10001,10002,10003] [10000,10001,10002,10003] 16 0.0 +17 [10000,10001,10002,10003] [10000,10001,10002,10003] 17 0.0 +18 [10000,10001,10002,10003] [10000,10001,10002,10003] 18 0.0 +19 [10000,10001,10002,10003] [10000,10001,10002,10003] 19 0.0 +20 [10000,10001,10002,10003] [10000,10001,10002,10003] 20 0.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1]) as dis from t1 +) +select * from w1 +where dis <= 100 and dis <= 1000 +order by dis limit 10; +-- result: +21 [1,1,1,1] [1,1,1,1] 21 0.0 +22 [2,2,2,2] [2,2,2,2] 22 4.0 +23 [3,3,3,3] [3,3,3,3] 23 16.0 +24 [4,4,4,4] [4,4,4,4] 24 36.0 +25 [5,5,5,5] [5,5,5,5] 25 64.0 +26 [6,6,6,6] [6,6,6,6] 26 100.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [640064, 640064, 640064, 640064]) as dis from t1 +) +select * from w1 +where dis <= 100 and dis <= 1000 +order by dis limit 10; +-- result: +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [1, 2, 3, 4]) as dis from t1 +) +select * from w1 +where dis <= 100 and dis <= 1000 +order by dis limit 10; +-- result: +22 [2,2,2,2] [2,2,2,2] 22 6.0 +23 [3,3,3,3] [3,3,3,3] 23 6.0 +24 [4,4,4,4] [4,4,4,4] 24 14.0 +21 [1,1,1,1] [1,1,1,1] 21 14.0 +25 [5,5,5,5] [5,5,5,5] 25 30.0 +26 [6,6,6,6] [6,6,6,6] 26 54.0 +27 [7,7,7,7] [7,7,7,7] 27 86.0 +-- !result +with w1 as ( + select *, approx_l2_distance(v1, [10000, 10001, 10002, 10003]) as dis from t1 +), w2 as ( + select * from w1 + where dis <= 100 and dis <= 1000 + order by dis limit 20 +) select * 
from w2 order by dis, id; +-- result: +1 [10000,10001,10002,10003] [10000,10001,10002,10003] 1 0.0 +2 [10000,10001,10002,10003] [10000,10001,10002,10003] 2 0.0 +3 [10000,10001,10002,10003] [10000,10001,10002,10003] 3 0.0 +4 [10000,10001,10002,10003] [10000,10001,10002,10003] 4 0.0 +5 [10000,10001,10002,10003] [10000,10001,10002,10003] 5 0.0 +6 [10000,10001,10002,10003] [10000,10001,10002,10003] 6 0.0 +7 [10000,10001,10002,10003] [10000,10001,10002,10003] 7 0.0 +8 [10000,10001,10002,10003] [10000,10001,10002,10003] 8 0.0 +9 [10000,10001,10002,10003] [10000,10001,10002,10003] 9 0.0 +10 [10000,10001,10002,10003] [10000,10001,10002,10003] 10 0.0 +11 [10000,10001,10002,10003] [10000,10001,10002,10003] 11 0.0 +12 [10000,10001,10002,10003] [10000,10001,10002,10003] 12 0.0 +13 [10000,10001,10002,10003] [10000,10001,10002,10003] 13 0.0 +14 [10000,10001,10002,10003] [10000,10001,10002,10003] 14 0.0 +15 [10000,10001,10002,10003] [10000,10001,10002,10003] 15 0.0 +16 [10000,10001,10002,10003] [10000,10001,10002,10003] 16 0.0 +17 [10000,10001,10002,10003] [10000,10001,10002,10003] 17 0.0 +18 [10000,10001,10002,10003] [10000,10001,10002,10003] 18 0.0 +19 [10000,10001,10002,10003] [10000,10001,10002,10003] 19 0.0 +20 [10000,10001,10002,10003] [10000,10001,10002,10003] 20 0.0 +-- !result +ADMIN SET FRONTEND CONFIG("enable_experimental_vector" = "false"); +-- result: +-- !result \ No newline at end of file diff --git a/test/sql/test_vector_index/T/test_vector_index b/test/sql/test_vector_index/T/test_vector_index index f50e60de399ad..e4dae64b23523 100644 --- a/test/sql/test_vector_index/T/test_vector_index +++ b/test/sql/test_vector_index/T/test_vector_index @@ -1,5 +1,6 @@ --- name: test_create_vector_index +-- name: test_create_vector_index @sequential ADMIN SET FRONTEND CONFIG("enable_experimental_vector" = "true"); + CREATE TABLE `t_test_vector_table` ( `id` bigint(20) NOT NULL COMMENT "", `vector1` ARRAY NOT NULL COMMENT "", @@ -14,15 +15,23 @@ PROPERTIES ( "replicated_storage" 
= "false", "compression" = "LZ4" ); -CREATE INDEX index_vector2 ON t_test_vector_table (vector2) USING VECTOR ("metric_type" = "l2_distance", "is_vector_normed" = "false", "index_type" = "ivfpq", "dim"="5", "nlist" = "256", "nbits"="10"); + +DROP INDEX index_vector1 ON t_test_vector_table; + +CREATE INDEX index_vector2 ON t_test_vector_table (vector2) USING VECTOR ("metric_type" = "l2_distance", "is_vector_normed" = "false", "index_type" = "ivfpq", "dim"="4", "nlist" = "256", "nbits"="8", "M_IVFPQ"="2"); DROP INDEX index_vector2 ON t_test_vector_table; -ALTER TABLE t_test_vector_table add index index_vector2 (vector2) USING VECTOR ("metric_type" = "l2_distance", "is_vector_normed" = "false", "index_type" = "ivfpq", "dim"="5", "nlist" = "256", "nbits"="10"); +ALTER TABLE t_test_vector_table add index index_vector2 (vector2) USING VECTOR ("metric_type" = "l2_distance", "is_vector_normed" = "false", "index_type" = "ivfpq", "dim"="4", "nlist" = "256", "nbits"="8", "M_IVFPQ"="2"); ALTER TABLE t_test_vector_table drop index index_vector2; DROP TABLE t_test_vector_table; --- name: test_vector_index +ADMIN SET FRONTEND CONFIG("enable_experimental_vector" = "false"); + +-- name: test_vector_index @sequential + +ADMIN SET FRONTEND CONFIG("enable_experimental_vector" = "true"); + CREATE TABLE `t_test_vector_table` ( `id` bigint(20) NOT NULL COMMENT "", `vector1` ARRAY NOT NULL COMMENT "", @@ -42,4 +51,7 @@ insert into t_test_vector_table values(2, [4,5,6,7,8]); select id, approx_l2_distance([1,1,1,1,1], vector1) from t_test_vector_table order by approx_l2_distance([1,1,1,1,1], vector1) limit 1; select * from (select id, approx_l2_distance([1,1,1,1,1], vector1) score from t_test_vector_table) a where score < 40 order by score limit 1; -DROP TABLE t_test_vector_table; \ No newline at end of file +DROP TABLE t_test_vector_table; + + +ADMIN SET FRONTEND CONFIG("enable_experimental_vector" = "false"); diff --git a/test/sql/test_vector_index/T/test_vector_index_hnsw 
b/test/sql/test_vector_index/T/test_vector_index_hnsw new file mode 100644 index 0000000000000..65944332a7435 --- /dev/null +++ b/test/sql/test_vector_index/T/test_vector_index_hnsw @@ -0,0 +1,249 @@ +-- name: test_vector_index_hnsw @sequential + +ADMIN SET FRONTEND CONFIG("enable_experimental_vector" = "true"); + + + +CREATE TABLE __row_util_base ( + k1 bigint NULL +) ENGINE=OLAP +DUPLICATE KEY(`k1`) +DISTRIBUTED BY HASH(`k1`) BUCKETS 32 +PROPERTIES ( + "replication_num" = "1" +); +insert into __row_util_base select generate_series from TABLE(generate_series(0, 10000 - 1)); +insert into __row_util_base select * from __row_util_base; -- 20000 +insert into __row_util_base select * from __row_util_base; -- 40000 +insert into __row_util_base select * from __row_util_base; -- 80000 +insert into __row_util_base select * from __row_util_base; -- 160000 +insert into __row_util_base select * from __row_util_base; -- 320000 +insert into __row_util_base select * from __row_util_base; -- 640000 + +CREATE TABLE __row_util ( + idx bigint NULL +) ENGINE=OLAP +DUPLICATE KEY(`idx`) +DISTRIBUTED BY HASH(`idx`) BUCKETS 32 +PROPERTIES ( + "replication_num" = "1" +); + +insert into __row_util +select + row_number() over() as idx +from __row_util_base; + + +CREATE TABLE t2 ( + id bigint(20) NOT NULL, + v1 ARRAY NOT NULL, + v2 ARRAY NOT NULL, + i1 bigint(20) NOT NULL, + INDEX index_vector (v1) USING VECTOR ( + "index_type" = "hnsw", + "dim"="5", + "metric_type" = "l2_distance", + "is_vector_normed" = "false", + "M" = "160", + "efconstruction" = "400") +) ENGINE=OLAP +DUPLICATE KEY(id) +DISTRIBUTED BY HASH(id) BUCKETS 64 +PROPERTIES ( + "replication_num" = "1" +); + + +insert into t2 +select + idx, + array_generate(10000, 10004), + array_generate(10000, 10004), + idx +from __row_util +order by idx +limit 20; + + +with w1 as ( + select *, approx_l2_distance(v1, [10000, 10001, 10002, 10003, 10004]) as dis from t2 +), w2 as ( + select * from w1 + order by dis limit 21 +) select * from w2 
order by dis, id; + +-- cannot use vector index +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1, 1]) as dis from t2 +) +select * from w1 +order by dis; + +insert into t2 +select + idx + 20, + array_repeat(idx, 5), + array_repeat(idx, 5), + idx + 20 +from __row_util; + +-- basic queries. + +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1, 1]) as dis from t2 +) +select * from w1 order by dis limit 10; + + +with w1 as ( + select *, approx_l2_distance(v1, [640064, 640064, 640064, 640064, 640064]) as dis from t2 +) +select * from w1 +order by dis limit 10; + + +with w1 as ( + select *, approx_l2_distance(v1, [1, 2, 3, 4, 5]) as dis from t2 +) +select * from w1 +order by dis limit 10; + + +with w1 as ( + select *, approx_l2_distance(v1, [10000, 10001, 10002, 10003, 10004]) as dis from t2 +), w2 as ( + select * from w1 + order by dis limit 20 +) select * from w2 order by dis, id; + +-- queries predicates + +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1, 1]) as dis from t2 +) +select * from w1 +where dis <= 0 +order by dis limit 10; + +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1, 1]) as dis from t2 +) +select * from w1 +where dis <= 100 and dis <= 1000 +order by dis limit 10; + + +with w1 as ( + select *, approx_l2_distance(v1, [640064, 640064, 640064, 640064, 640064]) as dis from t2 +) +select * from w1 +where dis <= 100 and dis <= 1000 +order by dis limit 10; + + +with w1 as ( + select *, approx_l2_distance(v1, [1, 2, 3, 4, 5]) as dis from t2 +) +select * from w1 +where dis <= 100 and dis <= 1000 +order by dis limit 10; + + +with w1 as ( + select *, approx_l2_distance(v1, [10000, 10001, 10002, 10003, 10004]) as dis from t2 +), w2 as ( + select * from w1 + where dis <= 100 and dis <= 1000 + order by dis limit 20 +) select * from w2 order by dis, id; + + +-- cannot use vector index +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1, 1]) as dis from t2 +) +select * from w1 +where dis >= 100 and dis <= 1000 
+order by dis limit 10; + +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1, 1]) as dis from t2 +) +select * from w1 +where dis <= 100 and id >= 0 +order by dis limit 10; + +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1, 1]) as dis from t2 +) +select * from w1 +order by dis, id limit 10; + +ADMIN SET FRONTEND CONFIG("enable_experimental_vector" = "false"); + +-- basic queries. + +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1, 1]) as dis from t2 +) +select * from w1 order by dis limit 10; + + +with w1 as ( + select *, approx_l2_distance(v1, [640064, 640064, 640064, 640064, 640064]) as dis from t2 +) +select * from w1 +order by dis limit 10; + + +with w1 as ( + select *, approx_l2_distance(v1, [1, 2, 3, 4, 5]) as dis from t2 +) +select * from w1 +order by dis limit 10; + + +with w1 as ( + select *, approx_l2_distance(v1, [10000, 10001, 10002, 10003, 10004]) as dis from t2 +), w2 as ( + select * from w1 + order by dis limit 20 +) select * from w2 order by dis, id; + +-- queries predicates + +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1, 1]) as dis from t2 +) +select * from w1 +where dis <= 100 and dis <= 1000 +order by dis limit 10; + + +with w1 as ( + select *, approx_l2_distance(v1, [640064, 640064, 640064, 640064, 640064]) as dis from t2 +) +select * from w1 +where dis <= 100 and dis <= 1000 +order by dis limit 10; + + +with w1 as ( + select *, approx_l2_distance(v1, [1, 2, 3, 4, 5]) as dis from t2 +) +select * from w1 +where dis <= 100 and dis <= 1000 +order by dis limit 10; + + +with w1 as ( + select *, approx_l2_distance(v1, [10000, 10001, 10002, 10003, 10004]) as dis from t2 +), w2 as ( + select * from w1 + where dis <= 100 and dis <= 1000 + order by dis limit 20 +) select * from w2 order by dis, id; + + +ADMIN SET FRONTEND CONFIG("enable_experimental_vector" = "false"); diff --git a/test/sql/test_vector_index/T/test_vector_index_insert b/test/sql/test_vector_index/T/test_vector_index_insert new file mode 
100644 index 0000000000000..c7a5b6e18d06b --- /dev/null +++ b/test/sql/test_vector_index/T/test_vector_index_insert @@ -0,0 +1,120 @@ +-- name: test_vector_index_insert @sequential + +ADMIN SET FRONTEND CONFIG("enable_experimental_vector" = "true"); + +CREATE TABLE t1 ( + id bigint(20) NOT NULL, + v1 ARRAY NOT NULL, + v2 ARRAY NULL, + INDEX index_vector (v1) USING VECTOR ( + "index_type" = "hnsw", + "dim"="5", + "metric_type" = "cosine_similarity", + "is_vector_normed" = "true", + "M" = "16", + "efconstruction" = "40") +) ENGINE=OLAP +DUPLICATE KEY(id) +DISTRIBUTED BY HASH(id) BUCKETS 64 +PROPERTIES ( + "replication_num" = "1" +); + +INSERT into t1 values + (1, null, null); + +INSERT into t1 values + (1, [null, null, null, null, null], [1,2,3,4,5]); + +INSERT into t1 values + (1, [1,2,3,4], [1,2,3,4]); + +INSERT into t1 values + (1, [], []); + +INSERT INTO t1 values + (1, [1,2,3,4,5], [1,2,3,4,5]), + (2, [4,5,6,7,8], [4,5,6,7,8]); + +INSERT INTO t1 values + (1, [0.13483997249264842, 0.26967994498529685, 0.40451991747794525, 0.5393598899705937, 0.674199862463242], + [0.13483997249264842, 0.26967994498529685, 0.40451991747794525, 0.5393598899705937, 0.674199862463242]), + (2, [0.29019050004400465, 0.36273812505500586, 0.435285750066007, 0.5078333750770082, 0.5803810000880093], + [0.29019050004400465, 0.36273812505500586, 0.435285750066007, 0.5078333750770082, 0.5803810000880093]), + (3, [0.3368607684266076, 0.42107596053325946, 0.5052911526399114, null, 0.6737215368532152], + [0.3368607684266076, 0.42107596053325946, 0.5052911526399114, null, 0.6737215368532152]), + (4, [0.3368607684266076, 0.42107596053325946, 0.5052911526399114, null, 0.6737215368532152], + null); + + +INSERT INTO t1 values + (1, [1,2,3,4,5], [1,2,3,4,5]), + (2, [4,5,6,7,8], [4,5,6,7,8]), + (3, null, null); + +INSERT INTO t1 values + (1, [0.13483997249264842, 0.26967994498529685, 0.40451991747794525, 0.5393598899705937, 0.674199862463242], + [0.13483997249264842, 0.26967994498529685, 
0.40451991747794525, 0.5393598899705937, 0.674199862463242]), + (4, null, null), + (2, [0.29019050004400465, 0.36273812505500586, 0.435285750066007, 0.5078333750770082, 0.5803810000880093], + [0.29019050004400465, 0.36273812505500586, 0.435285750066007, 0.5078333750770082, 0.5803810000880093]), + (3, [0.3368607684266076, 0.42107596053325946, 0.5052911526399114, null, 0.6737215368532152], + [0.3368607684266076, 0.42107596053325946, 0.5052911526399114, null, 0.6737215368532152]), + (3, null, null), + (5, null, null), + (6, null, null), + (7, null, null), + (8, null, null), + (9, null, null), + (10, null, null); + +select * from t1 order by id; + + +CREATE TABLE t2 ( + id bigint(20) NOT NULL, + v1 ARRAY NOT NULL, + v2 ARRAY NULL, + INDEX index_vector (v1) USING VECTOR ( + "index_type" = "hnsw", + "dim"="5", + "metric_type" = "cosine_similarity", + "is_vector_normed" = "false", + "M" = "16", + "efconstruction" = "40") +) ENGINE=OLAP +DUPLICATE KEY(id) +DISTRIBUTED BY HASH(id) BUCKETS 64 +PROPERTIES ( + "replication_num" = "1" +); + +INSERT INTO t2 values + (1, [1,2,3,4,5], [1,2,3,4,5]), + (2, [4,5,6,7,8], [4,5,6,7,8]), + (3, [4,5,6,null,8], [4,5,6,null,8]), + (4, [null, null, null, null], [null, null, null, null]), + (5, [4,5,6,7,8], null); + + +INSERT INTO t2 values + (1, [1,2,3,4,5], [1,2,3,4,5]), + (2, [4,5,6,7], [4,5,6,7,8]), + (3, [4,5,6,null,8], [4,5,6,null,8]), + (4, [null, null, null, null], [null, null, null, null]), + (5, [4,5,6,7,8], null); + + +select * from t2 order by id, v1, v2; + + +insert into t1 select * from t2; +insert into t1 select * from t1; +select * from t1 order by id, v1, v2; + +insert into t2 select * from t1; +insert into t2 select id, v2, v1 from t2; +select * from t2 order by id, v1, v2; + + +ADMIN SET FRONTEND CONFIG("enable_experimental_vector" = "false"); diff --git a/test/sql/test_vector_index/T/test_vector_index_ivfpq b/test/sql/test_vector_index/T/test_vector_index_ivfpq new file mode 100644 index 0000000000000..f415e7c27fc23 --- 
/dev/null +++ b/test/sql/test_vector_index/T/test_vector_index_ivfpq @@ -0,0 +1,271 @@ +-- name: test_vector_index_ivfpq @sequential + +ADMIN SET FRONTEND CONFIG("enable_experimental_vector" = "true"); + + + +CREATE TABLE __row_util_base ( + k1 bigint NULL +) ENGINE=OLAP +DUPLICATE KEY(`k1`) +DISTRIBUTED BY HASH(`k1`) BUCKETS 32 +PROPERTIES ( + "replication_num" = "1" +); +insert into __row_util_base select generate_series from TABLE(generate_series(0, 10000 - 1)); +insert into __row_util_base select * from __row_util_base; -- 20000 +insert into __row_util_base select * from __row_util_base; -- 40000 +insert into __row_util_base select * from __row_util_base; -- 80000 +insert into __row_util_base select * from __row_util_base; -- 160000 +insert into __row_util_base select * from __row_util_base; -- 320000 +insert into __row_util_base select * from __row_util_base; -- 640000 + +CREATE TABLE __row_util ( + idx bigint NULL +) ENGINE=OLAP +DUPLICATE KEY(`idx`) +DISTRIBUTED BY HASH(`idx`) BUCKETS 32 +PROPERTIES ( + "replication_num" = "1" +); + +insert into __row_util +select + row_number() over() as idx +from __row_util_base; + + +CREATE TABLE t1 ( + id bigint(20) NOT NULL, + v1 ARRAY NOT NULL, + v2 ARRAY NOT NULL, + i1 bigint(20) NOT NULL, + INDEX index_vector (v1) USING VECTOR ( + "index_type" = "IVFPQ", + "dim"="4", + "metric_type" = "l2_distance", + "is_vector_normed" = "false", + "nbits" = "8", + "nlist" = "40", + "M_IVFPQ" = "2") +) ENGINE=OLAP +DUPLICATE KEY(id) +DISTRIBUTED BY HASH(id) BUCKETS 64 +PROPERTIES ( + "replication_num" = "1" +); + + +CREATE TABLE t1 ( + id bigint(20) NOT NULL, + v1 ARRAY NOT NULL, + v2 ARRAY NOT NULL, + i1 bigint(20) NOT NULL, + INDEX index_vector (v1) USING VECTOR ( + "index_type" = "ivfpq", + "dim"="4", + "metric_type" = "l2_distance", + "is_vector_normed" = "false", + "nbits" = "8", + "nlist" = "16", + "M_IVFPQ" = "2") +) ENGINE=OLAP +DUPLICATE KEY(id) +DISTRIBUTED BY HASH(id) BUCKETS 64 +PROPERTIES ( + "replication_num" = "1" +); 
+ + +insert into t1 +select + idx, + array_generate(10000, 10003), + array_generate(10000, 10003), + idx +from __row_util +order by idx +limit 20; + +with w1 as ( + select *, approx_l2_distance(v1, [10000, 10001, 10002, 10003]) as dis from t1 +), w2 as ( + select * from w1 + order by dis limit 21 +) select * from w2 order by dis, id; + +-- cannot use vector index +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1]) as dis from t1 +) +select * from w1 +order by dis; + +insert into t1 +select + idx + 20, + array_repeat(idx, 4), + array_repeat(idx, 4), + idx + 20 +from __row_util; + +-- basic queries. + +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1]) as dis from t1 +) +select id, v1, dis from w1 order by dis limit 10; + + +with w1 as ( + select *, approx_l2_distance(v1, [640064, 640064, 640064, 640064]) as dis from t1 +) +select * from w1 +order by dis limit 10; + + +with w1 as ( + select *, approx_l2_distance(v1, [1, 2, 3, 4]) as dis from t1 +) +select * from w1 +order by dis limit 10; + + +with w1 as ( + select *, approx_l2_distance(v1, [10000, 10001, 10002, 10003]) as dis from t1 +), w2 as ( + select * from w1 + order by dis limit 20 +) select * from w2 order by dis, id; + +-- queries predicates + +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1]) as dis from t1 +) +select v1, dis from w1 +where dis <= 0 +order by dis limit 10; + + +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1]) as dis from t1 +) +select v1, dis from w1 +where dis <= 100 +order by dis limit 10; + + +with w1 as ( + select *, approx_l2_distance(v1, [640064, 640064, 640064, 640064]) as dis from t1 +) +select v1, dis from w1 +where dis <= 100 and dis <= 1000 +order by dis limit 10; + + +with w1 as ( + select *, approx_l2_distance(v1, [1, 2, 3, 4]) as dis from t1 +) +select v1, dis from w1 +where dis <= 100 and dis <= 1000 +order by dis limit 10; + + +with w1 as ( + select *, approx_l2_distance(v1, [10000, 10001, 10002, 10003]) as dis from t1 +), w2 as 
( + select * from w1 + where dis <= 100 and dis <= 1000 + order by dis limit 20 +) select v1, dis from w2 order by dis, id; + + +-- cannot use vector index +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1]) as dis from t1 +) +select * from w1 +where dis >= 100 and dis <= 1000 +order by dis limit 10; + +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1]) as dis from t1 +) +select * from w1 +where dis <= 100 and id >= 0 +order by dis limit 10; + +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1]) as dis from t1 +) +select * from w1 +order by dis, id limit 10; + +ADMIN SET FRONTEND CONFIG("enable_experimental_vector" = "false"); + +-- basic queries. + +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1]) as dis from t1 +) +select * from w1 order by dis limit 10; + + +with w1 as ( + select *, approx_l2_distance(v1, [640064, 640064, 640064, 640064]) as dis from t1 +) +select * from w1 +order by dis limit 10; + + +with w1 as ( + select *, approx_l2_distance(v1, [1, 2, 3, 4]) as dis from t1 +) +select * from w1 +order by dis limit 10; + + +with w1 as ( + select *, approx_l2_distance(v1, [10000, 10001, 10002, 10003]) as dis from t1 +), w2 as ( + select * from w1 + order by dis limit 20 +) select * from w2 order by dis, id; + +-- queries predicates + +with w1 as ( + select *, approx_l2_distance(v1, [1, 1, 1, 1]) as dis from t1 +) +select * from w1 +where dis <= 100 and dis <= 1000 +order by dis limit 10; + + +with w1 as ( + select *, approx_l2_distance(v1, [640064, 640064, 640064, 640064]) as dis from t1 +) +select * from w1 +where dis <= 100 and dis <= 1000 +order by dis limit 10; + + +with w1 as ( + select *, approx_l2_distance(v1, [1, 2, 3, 4]) as dis from t1 +) +select * from w1 +where dis <= 100 and dis <= 1000 +order by dis limit 10; + + +with w1 as ( + select *, approx_l2_distance(v1, [10000, 10001, 10002, 10003]) as dis from t1 +), w2 as ( + select * from w1 + where dis <= 100 and dis <= 1000 + order by dis limit 20 +) 
select * from w2 order by dis, id; + + +ADMIN SET FRONTEND CONFIG("enable_experimental_vector" = "false");