Skip to content

Commit

Permalink
extend unlimted offset for query iterator(#2418)
Browse files Browse the repository at this point in the history
Signed-off-by: MrPresent-Han <[email protected]>
  • Loading branch information
MrPresent-Han committed Dec 9, 2024
1 parent a225cbf commit f16345d
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 36 deletions.
41 changes: 29 additions & 12 deletions examples/orm/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
PICTURE = "picture"
CONSISTENCY_LEVEL = "Eventually"
LIMIT = 5
NUM_ENTITIES = 1000
NUM_ENTITIES = 10000
DIM = 8
CLEAR_EXIST = True

Expand Down Expand Up @@ -51,8 +51,7 @@ def re_create_collection(prepare_new_data: bool):
print(f"dropped existed collection{COLLECTION_NAME}")

fields = [
FieldSchema(name=USER_ID, dtype=DataType.VARCHAR, is_primary=True,
auto_id=False, max_length=MAX_LENGTH),
FieldSchema(name=USER_ID, dtype=DataType.INT64, is_primary=True, auto_id=False),
FieldSchema(name=AGE, dtype=DataType.INT64),
FieldSchema(name=DEPOSIT, dtype=DataType.DOUBLE),
FieldSchema(name=PICTURE, dtype=DataType.FLOAT_VECTOR, dim=DIM)
Expand Down Expand Up @@ -80,10 +79,9 @@ def random_pk(filter_set: set, lower_bound: int, upper_bound: int) -> str:
def insert_data(collection):
rng = np.random.default_rng(seed=19530)
batch_count = 5
filter_set: set = {}
for i in range(batch_count):
entities = [
[random_pk(filter_set, 0, batch_count * NUM_ENTITIES) for _ in range(NUM_ENTITIES)],
[i for i in range(NUM_ENTITIES*i, NUM_ENTITIES*(i + 1))],
[int(ni % 100) for ni in range(NUM_ENTITIES)],
[float(ni) for ni in range(NUM_ENTITIES)],
rng.random((NUM_ENTITIES, DIM)),
Expand Down Expand Up @@ -117,7 +115,7 @@ def query_iterate_collection_no_offset(collection):

query_iterator = collection.query_iterator(expr=expr, output_fields=[USER_ID, AGE],
offset=0, batch_size=5, consistency_level=CONSISTENCY_LEVEL,
reduce_stop_for_best="false", print_iterator_cursor=False,
reduce_stop_for_best="false",
iterator_cp_file="/tmp/it_cp")
no_best_ids: set = set({})
page_idx = 0
Expand All @@ -136,7 +134,7 @@ def query_iterate_collection_no_offset(collection):
print("best---------------------------")
query_iterator = collection.query_iterator(expr=expr, output_fields=[USER_ID, AGE],
offset=0, batch_size=5, consistency_level=CONSISTENCY_LEVEL,
reduce_stop_for_best="true", print_iterator_cursor=False, iterator_cp_file="/tmp/it_cp")
reduce_stop_for_best="true", iterator_cp_file="/tmp/it_cp")

best_ids: set = set({})
page_idx = 0
Expand All @@ -160,7 +158,23 @@ def query_iterate_collection_no_offset(collection):
def query_iterate_collection_with_offset(collection):
expr = f"10 <= {AGE} <= 14"
query_iterator = collection.query_iterator(expr=expr, output_fields=[USER_ID, AGE],
offset=10, batch_size=50, consistency_level=CONSISTENCY_LEVEL, print_iterator_cursor=True)
offset=10, batch_size=50, consistency_level=CONSISTENCY_LEVEL)
page_idx = 0
while True:
res = query_iterator.next()
if len(res) == 0:
print("query iteration finished, close")
query_iterator.close()
break
for i in range(len(res)):
print(res[i])
page_idx += 1
print(f"page{page_idx}-------------------------")


def query_iterate_collection_with_large_offset(collection):
query_iterator = collection.query_iterator(output_fields=[USER_ID, AGE],
offset=48000, batch_size=50, consistency_level=CONSISTENCY_LEVEL)
page_idx = 0
while True:
res = query_iterator.next()
Expand All @@ -177,7 +191,7 @@ def query_iterate_collection_with_offset(collection):
def query_iterate_collection_with_limit(collection):
expr = f"10 <= {AGE} <= 44"
query_iterator = collection.query_iterator(expr=expr, output_fields=[USER_ID, AGE],
batch_size=80, limit=530, consistency_level=CONSISTENCY_LEVEL, print_iterator_cursor=True)
batch_size=80, limit=530, consistency_level=CONSISTENCY_LEVEL)
page_idx = 0
while True:
res = query_iterator.next()
Expand All @@ -191,6 +205,8 @@ def query_iterate_collection_with_limit(collection):
print(f"page{page_idx}-------------------------")




def search_iterator_collection(collection):
SEARCH_NQ = 1
DIM = 8
Expand All @@ -201,7 +217,7 @@ def search_iterator_collection(collection):
"params": {"nprobe": 10, "radius": 1.0},
}
search_iterator = collection.search_iterator(vectors_to_search, PICTURE, search_params, batch_size=500,
output_fields=[USER_ID], print_iterator_cursor=True)
output_fields=[USER_ID])
page_idx = 0
while True:
res = search_iterator.next()
Expand All @@ -225,7 +241,7 @@ def search_iterator_collection_with_limit(collection):
"params": {"nprobe": 10, "radius": 1.0},
}
search_iterator = collection.search_iterator(vectors_to_search, PICTURE, search_params, batch_size=200, limit=755,
output_fields=[USER_ID], print_iterator_cursor=True)
output_fields=[USER_ID])
page_idx = 0
while True:
res = search_iterator.next()
Expand All @@ -240,11 +256,12 @@ def search_iterator_collection_with_limit(collection):


def main():
prepare_new_data = True
prepare_new_data = False
connections.connect("default", host=HOST, port=PORT)
collection = re_create_collection(prepare_new_data)
if prepare_new_data:
collection = prepare_data(collection)
query_iterate_collection_with_large_offset(collection)
query_iterate_collection_no_offset(collection)
query_iterate_collection_with_offset(collection)
query_iterate_collection_with_limit(collection)
Expand Down
1 change: 0 additions & 1 deletion pymilvus/orm/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
REDUCE_STOP_FOR_BEST = "reduce_stop_for_best"
ITERATOR_FIELD = "iterator"
ITERATOR_SESSION_TS_FIELD = "iterator_session_ts"
PRINT_ITERATOR_CURSOR = "print_iterator_cursor"
DEFAULT_MAX_L2_DISTANCE = 99999999.0
DEFAULT_MIN_IP_DISTANCE = -99999999.0
DEFAULT_MAX_HAMMING_DISTANCE = 99999999.0
Expand Down
66 changes: 43 additions & 23 deletions pymilvus/orm/iterator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import logging
import time
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
Expand Down Expand Up @@ -37,7 +38,6 @@
MILVUS_LIMIT,
OFFSET,
PARAMS,
PRINT_ITERATOR_CURSOR,
RADIUS,
RANGE_FILTER,
REDUCE_STOP_FOR_BEST,
Expand Down Expand Up @@ -113,7 +113,6 @@ def __init__(
self.__check_set_batch_size(batch_size)
self._limit = limit
self.__check_set_reduce_stop_for_best()
check_set_flag(self, "_print_iterator_cursor", self._kwargs, PRINT_ITERATOR_CURSOR)
self._returned_count = 0
self.__setup__pk_prop()
self.__set_up_expr(expr)
Expand All @@ -131,18 +130,42 @@ def __seek_to_offset(self):
if offset > 0:
seek_params = self._kwargs.copy()
seek_params[OFFSET] = 0
seek_params[MILVUS_LIMIT] = offset
res = self._conn.query(
collection_name=self._collection_name,
expr=self._expr,
output_field=self._output_fields,
partition_name=self._partition_names,
timeout=self._timeout,
**seek_params,
)
result_index = min(len(res), offset)
self.__update_cursor(res[:result_index])
seek_params[ITERATOR_FIELD] = "False"
seek_params[REDUCE_STOP_FOR_BEST] = "False"
start_time = time.time()

def seek_offset_by_batch(batch: int, expr: str) -> int:
seek_params[MILVUS_LIMIT] = batch
res = self._conn.query(
collection_name=self._collection_name,
expr=expr,
output_field=[],
partition_name=self._partition_names,
timeout=self._timeout,
**seek_params,
)
self.__update_cursor(res)
return len(res)

while offset > 0:
batch_size = min(MAX_BATCH_SIZE, offset)
next_expr = self.__setup_next_expr()
seeked_count = seek_offset_by_batch(batch_size, next_expr)
LOGGER.info(
f"seeked offset, seek_expr:{next_expr} batch_size:{batch_size} seeked_count:{seeked_count}"
)
if seeked_count == 0:
LOGGER.info(
"seek offset has drained all matched results for query iterator, break"
)
break
offset -= seeked_count
self._kwargs[OFFSET] = 0
seek_offset_duration = time.time() - start_time
LOGGER.info(
f"Finish seek offset for query iterator, offset:{offset}, current_pk_cursor:{self._next_id}, "
f"duration:{seek_offset_duration}"
)

def __init_cp_file_handler(self) -> bool:
mode = "w"
Expand Down Expand Up @@ -170,14 +193,14 @@ def __save_pk_cursor(self):
self._cp_file_handler = self._cp_file_path.open("w")
self._buffer_cursor_lines_number = 0
self.__save_mvcc_ts()
log.warning(
LOGGER.warning(
"iterator cp file is not existed any more, recreate for iteration, "
"do not remove this file manually!"
)
if self._buffer_cursor_lines_number >= 100:
self._cp_file_handler.seek(0)
self._cp_file_handler.truncate()
log.info(
LOGGER.info(
"cursor lines in cp file has exceeded 100 lines, truncate the file and rewrite"
)
self._buffer_cursor_lines_number = 0
Expand Down Expand Up @@ -229,7 +252,7 @@ def __setup_ts_by_request(self):
if res.extra is not None:
self._session_ts = res.extra.get(ITERATOR_SESSION_TS_FIELD, 0)
if self._session_ts <= 0:
log.warning("failed to get mvccTs from milvus server, use client-side ts instead")
LOGGER.warning("failed to get mvccTs from milvus server, use client-side ts instead")
self._session_ts = fall_back_to_latest_session_ts()
self._kwargs[GUARANTEE_TIMESTAMP] = self._session_ts

Expand Down Expand Up @@ -291,8 +314,7 @@ def next(self):
else:
iterator_cache.release_cache(self._cache_id_in_use)
current_expr = self.__setup_next_expr()
if self._print_iterator_cursor:
log.info(f"query_iterator_next_expr:{current_expr}")
LOGGER.debug(f"query_iterator_next_expr:{current_expr}")
res = self._conn.query(
collection_name=self._collection_name,
expr=current_expr,
Expand Down Expand Up @@ -358,7 +380,7 @@ def close(self) -> None:
def inner_close():
self._cp_file_handler.close()
self._cp_file_path.unlink()
log.info(f"removed cp file:{self._cp_file_path_str} for query iterator")
LOGGER.info(f"removed cp file:{self._cp_file_path_str} for query iterator")

io_operation(
inner_close, f"failed to clear cp file:{self._cp_file_path_str} for query iterator"
Expand Down Expand Up @@ -482,14 +504,13 @@ def __init__(
self.__check_offset()
self.__check_rm_range_search_parameters()
self.__setup__pk_prop()
check_set_flag(self, "_print_iterator_cursor", self._kwargs, PRINT_ITERATOR_CURSOR)
self.__init_search_iterator()

def __init_search_iterator(self):
init_page = self.__execute_next_search(self._param, self._expr, False)
self._session_ts = init_page.get_session_ts()
if self._session_ts <= 0:
log.warning("failed to set up mvccTs from milvus server, use client-side ts instead")
LOGGER.warning("failed to set up mvccTs from milvus server, use client-side ts instead")
self._session_ts = fall_back_to_latest_session_ts()
self._kwargs[GUARANTEE_TIMESTAMP] = self._session_ts
if len(init_page) == 0:
Expand Down Expand Up @@ -693,8 +714,7 @@ def __try_search_fill(self) -> SearchPage:
def __execute_next_search(
self, next_params: dict, next_expr: str, to_extend_batch: bool
) -> SearchPage:
if self._print_iterator_cursor:
log.info(f"search_iterator_next_expr:{next_expr}, next_params:{next_params}")
LOGGER.debug(f"search_iterator_next_expr:{next_expr}, next_params:{next_params}")
res = self._conn.search(
self._iterator_params["collection_name"],
self._iterator_params["data"],
Expand Down

0 comments on commit f16345d

Please sign in to comment.