Skip to content

Commit

Permalink
Support search by id
Browse files Browse the repository at this point in the history
Signed-off-by: Yicheng Hu <[email protected]>
  • Loading branch information
Jellal-HT committed Oct 19, 2021
1 parent fff6b96 commit b2587c1
Show file tree
Hide file tree
Showing 11 changed files with 440 additions and 116 deletions.
120 changes: 64 additions & 56 deletions grpc-proto/gen/milvus_pb2.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions grpc-proto/gen/milvus_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ def GetMetrics(self, request, context):
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')



def add_MilvusServiceServicer_to_server(servicer, server):
Expand Down
1 change: 1 addition & 0 deletions grpc-proto/milvus.proto
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@ message SearchRequest {
repeated common.KeyValuePair search_params = 9; // must
uint64 travel_timestamp = 10;
uint64 guarantee_timestamp = 11; // guarantee_timestamp
schema.IDs searchIDs = 12; // search by ids
}

message Hits {
Expand Down
13 changes: 13 additions & 0 deletions pymilvus/client/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,16 @@ def is_legal_search_data(data):

return True

def is_legal_search_ids(ids):
if not isinstance(ids, list):
return False

for vector in ids:
if not isinstance(vector, int):
return False

return True


def is_legal_output_fields(output_fields):
if output_fields is None:
Expand Down Expand Up @@ -318,6 +328,9 @@ def check_pass_param(*args, **kwargs):
elif key in ("search_data",):
if not is_legal_search_data(value):
_raise_param_error(key, value)
elif key in ("search_ids",):
if not is_legal_search_ids(value):
_raise_param_error(key, value)
elif key in ("output_fields",):
if not is_legal_output_fields(value):
_raise_param_error(key, value)
Expand Down
18 changes: 18 additions & 0 deletions pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,24 @@ def search(self, collection_name, data, anns_field, param, limit,
_kwargs["auto_id"] = auto_id
_kwargs["round_decimal"] = round_decimal
return self._execute_search_requests(requests, timeout, **_kwargs)

@error_handler(None)
@check_has_collection
def search_by_id(self, collection_name, search_ids, anns_field, param, limit,
expression=None, partition_names=None, output_fields=None,
timeout=None, round_decimal=-1, **kwargs):
## do vector querying
_kwargs = copy.deepcopy(kwargs)
collection_schema = self.describe_collection(collection_name, timeout)
auto_id = collection_schema["auto_id"]
_kwargs["schema"] = collection_schema
requests = Prepare.search_requests_with_ids(collection_name, search_ids, anns_field, param, limit, expression,
partition_names, output_fields, round_decimal, **_kwargs)
_kwargs.pop("schema")
_kwargs["auto_id"] = auto_id
_kwargs["round_decimal"] = round_decimal
return self._execute_search_requests(requests, timeout, **_kwargs)


@error_handler(None)
def get_query_segment_infos(self, collection_name, timeout=30, **kwargs):
Expand Down
62 changes: 59 additions & 3 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def create_collection_request(cls, collection_name, fields, shards_num=2, **kwar
schema.fields.append(field_schema)

return milvus_types.CreateCollectionRequest(collection_name=collection_name,
schema=bytes(schema.SerializeToString()), shards_num = shards_num)
schema=bytes(schema.SerializeToString()), shards_num=shards_num)

@classmethod
def drop_collection_request(cls, collection_name):
Expand Down Expand Up @@ -381,7 +381,7 @@ def check_str(input, prefix):

# if partition_name is null or empty, delete action will apply to whole collection
partition_name = partition_name or ""
request = milvus_types.DeleteRequest(collection_name=collection_name, expr = expr, partition_name=partition_name)
request = milvus_types.DeleteRequest(collection_name=collection_name, expr=expr, partition_name=partition_name)
return request

@classmethod
Expand Down Expand Up @@ -508,7 +508,8 @@ def extract_vectors_param(param, placeholders, meta=None, names=None, round_deci
return requests

@classmethod
def search_request(cls, collection_name, query_entities, partition_names=None, fields=None, round_decimal=-1, **kwargs):
def search_request(cls, collection_name, query_entities, partition_names=None, fields=None, round_decimal=-1,
**kwargs):
schema = kwargs.get("schema", None)
fields_schema = schema.get("fields", None) # list
fields_name_locs = {fields_schema[loc]["name"]: loc
Expand Down Expand Up @@ -659,6 +660,61 @@ def dump(v):

return requests

@classmethod
def search_requests_with_ids(cls, collection_name, search_ids, anns_field, param, limit, expr=None,
partition_names=None,
output_fields=None, round_decimal=-1, **kwargs):
schema = kwargs.get("schema", None)
fields_schema = schema.get("fields", None) # list
fields_name_locs = {fields_schema[loc]["name"]: loc
for loc in range(len(fields_schema))}

requests = []

if len(search_ids) <= 0:
return requests

nq = len(search_ids)
## TODO: add MaxSearchResultSize check

if anns_field not in fields_name_locs:
raise ParamError(f"Field {anns_field} doesn't exist in schema")

param_copy = copy.deepcopy(param)
metric_type = param_copy.pop("metric_type", "L2")
params = param_copy.pop("params", {})
if not isinstance(params, dict):
raise ParamError("Search params must be a dict")
search_params = {"anns_field": anns_field, "topk": limit, "metric_type": metric_type, "params": params,
"round_decimal": round_decimal}

def dump(v):
if isinstance(v, dict):
return ujson.dumps(v)
return str(v)

request = milvus_types.SearchRequest(
collection_name=collection_name,
partition_names=partition_names,
output_fields=output_fields,
)

request.dsl_type = common_types.DslType.BoolExprV1
if expr is not None:
request.dsl = expr
request.search_params.extend([common_types.KeyValuePair(key=str(key), value=dump(value))
for key, value in search_params.items()])

# extract_search_ids
if (not isinstance(search_ids, list)) or len(search_ids) == 0 or not isinstance(search_ids[0], int):
raise ParamError("search ids array is empty or not a list or ids are not int type")

request.searchIDs.int_id.data.extend(search_ids)

requests.append(request)

return requests

@classmethod
def create_alias_request(cls, collection_name, alias):
return milvus_types.CreateAliasRequest(collection_name=collection_name, alias=alias)
Expand Down
55 changes: 55 additions & 0 deletions pymilvus/client/stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,6 +1066,61 @@ def search(self, collection_name, data, anns_field, param, limit, expression=Non
kwargs["_deploy_mode"] = self._deploy_mode
return handler.search(collection_name, data, anns_field, param, limit, expression,
partition_names, output_fields, timeout, round_decimal, **kwargs)

@retry_on_rpc_failure(retry_times=10, wait=1)
def search_by_id(self, collection_name, search_ids, anns_field, param, limit, expression=None, partition_names=None,
output_fields=None, timeout=None, round_decimal=-1, **kwargs):
"""
Searches a collection based on the given expression and returns query results.
:param collection_name: The name of the collection to search.
:type collection_name: str
:param search_ids: List of ids of the vectors to search, the length of search_ids is number of query (nq).
:type search_ids: list[int]
:param anns_field: The vector field used to search of collection.
:type anns_field: str
:param param: The parameters of search, such as nprobe, etc.
:type param: dict
:param limit: The max number of returned record, we also called this parameter as topk.
:type limit: int
:param expression: The boolean expression used to filter attribute.
:type expression: str
:param partition_names: The names of partitions to search.
:type partition_names: list[str]
:param output_fields: The fields to return in the search result, not supported now.
:type output_fields: list[str]
:param timeout: An optional duration of time in seconds to allow for the RPC. When timeout
is set to None, client waits until server response or error occur.
:type timeout: float
:param round_decimal: The specified number of decimal places of returned distance
:type round_decimal: int
:param kwargs:
* *_async* (``bool``) --
Indicate if invoke asynchronously. When value is true, method returns a SearchFuture object;
otherwise, method returns results from server.
* *_callback* (``function``) --
The callback function which is invoked after server response successfully. It only take
effect when _async is set to True.
:return: Query result. QueryResult is iterable and is a 2d-array-like class, the first dimension is
the number of vectors to query (nq), the second dimension is the number of limit(topk).
:rtype: QueryResult
:raises:
RpcError: If gRPC encounter an error
ParamError: If parameters are invalid
BaseException: If the return result from server is not ok
"""
check_pass_param(
limit=limit,
round_decimal=round_decimal,
anns_field=anns_field,
search_ids=search_ids,
partition_name_array=partition_names,
output_fields=output_fields,
)
with self._connection() as handler:
kwargs["_deploy_mode"] = self._deploy_mode
return handler.search_by_id(collection_name, search_ids, anns_field, param, limit, expression,
partition_names, output_fields, timeout, round_decimal, **kwargs)


@retry_on_rpc_failure(retry_times=10, wait=1)
def calc_distance(self, vectors_left, vectors_right, params=None, timeout=None, **kwargs):
Expand Down
120 changes: 64 additions & 56 deletions pymilvus/grpc_gen/milvus_pb2.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pymilvus/grpc_gen/milvus_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,7 @@ def GetMetrics(self, request, context):
raise NotImplementedError('Method not implemented!')



def add_MilvusServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'CreateCollection': grpc.unary_unary_rpc_method_handler(
Expand Down
86 changes: 85 additions & 1 deletion pymilvus/orm/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,90 @@ def search(self, data, anns_field, param, limit, expr=None, partition_names=None
if kwargs.get("_async", False):
return SearchFuture(res)
return SearchResult(res)

def search_by_id(self, search_ids, anns_field, param, limit, expr=None, partition_names=None,
output_fields=None, timeout=None, round_decimal=-1, **kwargs):
"""
Conducts a vector similarity search with an optional boolean expression as filter.
:param search_ids: List of ids of the vectors to search, the length of search_ids is number of query (nq).
:type search_ids: list[int]
:param anns_field: The vector field used to search of collection.
:type anns_field: str
:param param: The parameters of search, such as ``nprobe``.
:type param: dict
:param limit: The max number of returned record, also known as ``topk``.
:type limit: int
:param expr: The boolean expression used to filter attribute.
:type expr: str
:param partition_names: The names of partitions to search.
:type partition_names: list[str]
:param output_fields: The fields to return in the search result, not supported now.
:type output_fields: list[str]
:param timeout: An optional duration of time in seconds to allow for the RPC. When timeout
is set to None, client waits until server response or error occur.
:type timeout: float
:param round_decimal: The specified number of decimal places of returned distance
:type round_decimal: int
:param kwargs:
* *_async* (``bool``) --
Indicate if invoke asynchronously. When value is true, method returns a
SearchFuture object; otherwise, method returns results from server directly.
* *_callback* (``function``) --
The callback function which is invoked after server response successfully.
It functions only if _async is set to True.
:return: SearchResult:
SearchResult is iterable and is a 2d-array-like class, the first dimension is
the number of vectors to query (nq), the second dimension is the number of limit(topk).
:rtype: SearchResult
:raises RpcError: If gRPC encounter an error.
:raises ParamError: If parameters are invalid.
:raises DataTypeNotMatchException: If wrong type of param is passed.
:raises BaseException: If the return result from server is not ok.
:example:
>>> from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType
>>> import random
>>> connections.connect()
<pymilvus.client.stub.Milvus object at 0x7f8579002dc0>
>>> schema = CollectionSchema([
... FieldSchema("film_id", DataType.INT64, is_primary=True),
... FieldSchema("films", dtype=DataType.FLOAT_VECTOR, dim=2)
... ])
>>> collection = Collection("test_collection_search", schema)
>>> # insert
>>> data = [
... [i for i in range(10)],
... [[random.random() for _ in range(2)] for _ in range(10)],
... ]
>>> collection.insert(data)
>>> collection.num_entities
10
>>> collection.load()
>>> # search
>>> search_param = {
... "search_ids": [1],
... "anns_field": "films",
... "param": {"metric_type": "L2"},
... "limit": 2,
... "expr": "film_id > 0",
... }
>>> res = collection.search_by_id(**search_param)
>>> assert len(res) == 1
>>> hits = res[0]
>>> assert len(hits) == 2
>>> print(f"- Total hits: {len(hits)}, hits ids: {hits.ids} ")
- Total hits: 2, hits ids: [1, 9]
>>> print(f"- Top1 hit id: {hits[0].id}, distance: {hits[0].distance}, score: {hits[0].score} ")
- Top1 hit id: 1, distance: 0.0, score: 0.0
"""
if expr is not None and not isinstance(expr, str):
raise DataTypeNotMatchException(0, ExceptionsMessage.ExprType % type(expr))

conn = self._get_connection()
res = conn.search_by_id(self._name, search_ids, anns_field, param, limit, expr,
partition_names, output_fields, timeout, round_decimal, **kwargs)
if kwargs.get("_async", False):
return SearchFuture(res)
return SearchResult(res)

def query(self, expr, output_fields=None, partition_names=None, timeout=None):
"""
Expand Down Expand Up @@ -1200,4 +1284,4 @@ def alter_alias(self, alias, timeout=None, **kwargs):
otherwise return Status(code=1, message='alias does not exist')
"""
conn = self._get_connection()
conn.alter_alias(self._name, alias, timeout=timeout, **kwargs)
conn.alter_alias(self._name, alias, timeout=timeout, **kwargs)
Loading

0 comments on commit b2587c1

Please sign in to comment.