From d2841a3c69e3724b78fc5f8376d6370a0ba573f3 Mon Sep 17 00:00:00 2001
From: Yicheng Hu <75612148+Jellal-HT@users.noreply.github.com>
Date: Mon, 11 Oct 2021 18:58:18 -0600
Subject: [PATCH] Support search by id

Signed-off-by: Jellal-HT
---
 grpc-proto/gen/milvus_pb2.py         |  9 +++++++
 grpc-proto/gen/milvus_pb2_grpc.py    | 32 +++++++++++++++++++++++
 grpc-proto/milvus.proto              |  1 +
 pymilvus/client/grpc_handler.py      | 38 ++++++++++++++++++++++++++++
 pymilvus/client/stub.py              |  9 +++++++
 pymilvus/grpc_gen/milvus_pb2.py      |  9 +++++++
 pymilvus/grpc_gen/milvus_pb2_grpc.py | 33 ++++++++++++++++++++++++
 pymilvus/orm/collection.py           | 16 +++++++++++-
 pymilvus/orm/partition.py            | 10 ++++++++
 9 files changed, 156 insertions(+), 1 deletion(-)

diff --git a/grpc-proto/gen/milvus_pb2.py b/grpc-proto/gen/milvus_pb2.py
index dc4e880c9..80061055c 100644
--- a/grpc-proto/gen/milvus_pb2.py
+++ b/grpc-proto/gen/milvus_pb2.py
@@ -4546,6 +4546,15 @@
     serialized_options=None,
     create_key=_descriptor._internal_create_key,
   ),
+  _descriptor.MethodDescriptor(
+    name='GetVectorsByID',
+    full_name='milvus.proto.milvus.MilvusService.GetVectorsByID',
+    index=34,
+    containing_service=None,
+    input_type=_VECTORIDS,
+    output_type=_VECTORSARRAY,
+    serialized_options=None,
+  ),
 ])
 _sym_db.RegisterServiceDescriptor(_MILVUSSERVICE)
 
diff --git a/grpc-proto/gen/milvus_pb2_grpc.py b/grpc-proto/gen/milvus_pb2_grpc.py
index 5824191c0..63195a672 100644
--- a/grpc-proto/gen/milvus_pb2_grpc.py
+++ b/grpc-proto/gen/milvus_pb2_grpc.py
@@ -185,6 +185,11 @@ def __init__(self, channel):
                 request_serializer=milvus__pb2.GetMetricsRequest.SerializeToString,
                 response_deserializer=milvus__pb2.GetMetricsResponse.FromString,
                 )
+        self.GetVectorsByID = channel.unary_unary(
+                '/milvus.proto.milvus.MilvusService/GetVectorsByID',
+                request_serializer=milvus__pb2.VectorIDs.SerializeToString,
+                response_deserializer=milvus__pb2.VectorsArray.FromString,
+                )
 
 
 class MilvusServiceServicer(object):
@@ -395,6 +400,11 @@ def GetMetrics(self, request, context):
         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
         context.set_details('Method not implemented!')
         raise NotImplementedError('Method not implemented!')
+
+    def GetVectorsByID(self, request, context):
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
 
 
 def add_MilvusServiceServicer_to_server(servicer, server):
@@ -569,6 +579,11 @@ def add_MilvusServiceServicer_to_server(servicer, server):
                     request_deserializer=milvus__pb2.GetMetricsRequest.FromString,
                     response_serializer=milvus__pb2.GetMetricsResponse.SerializeToString,
             ),
+            'GetVectorsByID': grpc.unary_unary_rpc_method_handler(
+                    servicer.GetVectorsByID,
+                    request_deserializer=milvus__pb2.VectorIDs.FromString,
+                    response_serializer=milvus__pb2.VectorsArray.SerializeToString,
+            ),
     }
     generic_handler = grpc.method_handlers_generic_handler(
             'milvus.proto.milvus.MilvusService', rpc_method_handlers)
@@ -1156,6 +1171,23 @@ def GetMetrics(request,
             milvus__pb2.GetMetricsResponse.FromString,
             options, channel_credentials,
             insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+    @staticmethod
+    def GetVectorsByID(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(request, target, '/milvus.proto.milvus.MilvusService/GetVectorsByID',
+            milvus__pb2.VectorIDs.SerializeToString,
+            milvus__pb2.VectorsArray.FromString,
+            options, channel_credentials,
+            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
 
 
 class ProxyServiceStub(object):
diff --git a/grpc-proto/milvus.proto b/grpc-proto/milvus.proto
index a6c328007..59ce9891c 100644
--- a/grpc-proto/milvus.proto
+++ b/grpc-proto/milvus.proto
@@ -40,6 +40,7 @@ service MilvusService {
   rpc Flush(FlushRequest) returns (FlushResponse) {}
   rpc Query(QueryRequest) returns (QueryResults) {}
   rpc CalcDistance(CalcDistanceRequest) returns (CalcDistanceResults) {}
+  rpc GetVectorsByID(VectorIDs) returns (VectorsArray) {}
 
   rpc GetPersistentSegmentInfo(GetPersistentSegmentInfoRequest) returns (GetPersistentSegmentInfoResponse) {}
   rpc GetQuerySegmentInfo(GetQuerySegmentInfoRequest) returns (GetQuerySegmentInfoResponse) {}
diff --git a/pymilvus/client/grpc_handler.py b/pymilvus/client/grpc_handler.py
index c739d4aa0..6527e5074 100644
--- a/pymilvus/client/grpc_handler.py
+++ b/pymilvus/client/grpc_handler.py
@@ -621,6 +621,44 @@ def search(self, collection_name, data, anns_field, param, limit,
         _kwargs["auto_id"] = auto_id
         _kwargs["round_decimal"] = round_decimal
         return self._execute_search_requests(requests, timeout, **_kwargs)
+
+    def search_by_id(self, collection_name, query_id, anns_field, param, limit,
+                     expression=None, partition_tags=None, output_fields=None,
+                     timeout=None, round_decimal=-1, **kwargs):
+        """Fetch the vectors stored under ``query_id``, then search the collection with them."""
+        # Step 1: fail early if the collection is missing, then fetch the vectors by id.
+        rf = self._stub.HasCollection.future(Prepare.has_collection_request(collection_name), wait_for_ready=True,
+                                             timeout=timeout)
+        reply = rf.result()
+        if reply.status.error_code != 0 or not reply.value:
+            raise CollectionNotExistException(reply.status.error_code, "collection not exists")
+
+        request = milvus_types.VectorIDs(collection_name=collection_name, field_name=None, id_array=query_id,
+                                         partition_names=partition_tags)
+
+        future = self._stub.GetVectorsByID.future(request, wait_for_ready=True, timeout=timeout)
+        response = future.result()
+        # Vectors returned for query_id; they become the search input below.
+        vector = list()
+        if not response.data_array:
+            print("can not obtain vector")
+            return
+        else:
+            for datas in response.data_array:
+                data = bytes(datas.binary_data) or list(datas.float_data)
+                # Binary payload wins; otherwise fall back to the float payload.
+                vector.append(data)
+
+        # Step 2: run a regular vector search using the fetched vectors as query data.
+        _kwargs = copy.deepcopy(kwargs)
+        schema = self.describe_collection(collection_name, timeout)
+        _kwargs["schema"] = schema
+        _kwargs["auto_id"] = schema["auto_id"]
+        _kwargs["round_decimal"] = round_decimal
+        requests = Prepare.search_requests_with_expr(collection_name, vector, anns_field, param, limit, expression,
+                                                     partition_tags, output_fields, round_decimal, **_kwargs)
+        return self._execute_search_requests(requests, timeout, **_kwargs)
+
 
     @error_handler(None)
     def get_query_segment_infos(self, collection_name, timeout=30, **kwargs):
diff --git a/pymilvus/client/stub.py b/pymilvus/client/stub.py
index ebcc7fe91..f865c096a 100644
--- a/pymilvus/client/stub.py
+++ b/pymilvus/client/stub.py
@@ -1066,6 +1066,15 @@ def search(self, collection_name, data, anns_field, param, limit, expression=Non
             kwargs["_deploy_mode"] = self._deploy_mode
             return handler.search(collection_name, data, anns_field, param, limit, expression, partition_names,
                                   output_fields, timeout, round_decimal, **kwargs)
+
+    def search_by_id(self, collection_name, query_id, anns_field, param, limit,
+                     expression=None, partition_tags=None, output_fields=None,
+                     timeout=None, round_decimal=-1, **kwargs):
+        with self._connection() as handler:
+            kwargs["_deploy_mode"] = self._deploy_mode
+            return handler.search_by_id(collection_name, query_id, anns_field, param, limit, expression,
+                                        partition_tags, output_fields, timeout, round_decimal, **kwargs)
+
 
     @retry_on_rpc_failure(retry_times=10, wait=1)
     def calc_distance(self, vectors_left, vectors_right, params=None, timeout=None, **kwargs):
diff --git a/pymilvus/grpc_gen/milvus_pb2.py b/pymilvus/grpc_gen/milvus_pb2.py
index d6ab858f7..501a5d51e 100644
--- a/pymilvus/grpc_gen/milvus_pb2.py
+++ b/pymilvus/grpc_gen/milvus_pb2.py
@@ -4546,6 +4546,15 @@
     serialized_options=None,
     create_key=_descriptor._internal_create_key,
   ),
+  _descriptor.MethodDescriptor(
+    name='GetVectorsByID',
+    full_name='milvus.proto.milvus.MilvusService.GetVectorsByID',
+    index=34,
+    containing_service=None,
+    input_type=_VECTORIDS,
+    output_type=_VECTORSARRAY,
+    serialized_options=None,
+  ),
 ])
 _sym_db.RegisterServiceDescriptor(_MILVUSSERVICE)
 
diff --git a/pymilvus/grpc_gen/milvus_pb2_grpc.py b/pymilvus/grpc_gen/milvus_pb2_grpc.py
index bf9c3bf1f..8fc4607b8 100644
--- a/pymilvus/grpc_gen/milvus_pb2_grpc.py
+++ b/pymilvus/grpc_gen/milvus_pb2_grpc.py
@@ -185,6 +185,11 @@ def __init__(self, channel):
                 request_serializer=milvus__pb2.GetMetricsRequest.SerializeToString,
                 response_deserializer=milvus__pb2.GetMetricsResponse.FromString,
                 )
+        self.GetVectorsByID = channel.unary_unary(
+                '/milvus.proto.milvus.MilvusService/GetVectorsByID',
+                request_serializer=milvus__pb2.VectorIDs.SerializeToString,
+                response_deserializer=milvus__pb2.VectorsArray.FromString,
+                )
 
 
 class MilvusServiceServicer(object):
@@ -395,6 +400,12 @@ def GetMetrics(self, request, context):
         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
         context.set_details('Method not implemented!')
         raise NotImplementedError('Method not implemented!')
+
+    def GetVectorsByID(self, request, context):
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
 
 
 def add_MilvusServiceServicer_to_server(servicer, server):
@@ -569,6 +580,11 @@ def add_MilvusServiceServicer_to_server(servicer, server):
                     request_deserializer=milvus__pb2.GetMetricsRequest.FromString,
                     response_serializer=milvus__pb2.GetMetricsResponse.SerializeToString,
             ),
+            'GetVectorsByID': grpc.unary_unary_rpc_method_handler(
+                    servicer.GetVectorsByID,
+                    request_deserializer=milvus__pb2.VectorIDs.FromString,
+                    response_serializer=milvus__pb2.VectorsArray.SerializeToString,
+            ),
     }
     generic_handler = grpc.method_handlers_generic_handler(
             'milvus.proto.milvus.MilvusService', rpc_method_handlers)
@@ -1156,6 +1172,23 @@ def GetMetrics(request,
             milvus__pb2.GetMetricsResponse.FromString,
             options, channel_credentials,
             insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+
+    @staticmethod
+    def GetVectorsByID(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(request, target, '/milvus.proto.milvus.MilvusService/GetVectorsByID',
+            milvus__pb2.VectorIDs.SerializeToString,
+            milvus__pb2.VectorsArray.FromString,
+            options, channel_credentials,
+            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
 
 
 class ProxyServiceStub(object):
diff --git a/pymilvus/orm/collection.py b/pymilvus/orm/collection.py
index 7cffc908c..1398450c7 100644
--- a/pymilvus/orm/collection.py
+++ b/pymilvus/orm/collection.py
@@ -672,6 +672,20 @@ def search(self, data, anns_field, param, limit, expr=None, partition_names=None
         if kwargs.get("_async", False):
             return SearchFuture(res)
         return SearchResult(res)
+
+    def searchByID(self, query_id, anns_field, param, limit, expr=None, partition_names=None,
+                   output_fields=None, timeout=None, round_decimal=-1, **kwargs):
+        """Search this collection via the connection's ``search_by_id``, keyed by ``query_id``."""
+        if expr is not None and not isinstance(expr, str):
+            raise DataTypeNotMatchException(0, ExceptionsMessage.ExprType % type(expr))
+
+        conn = self._get_connection()
+        res = conn.search_by_id(self._name, query_id, anns_field, param, limit, expr,
+                                partition_names, output_fields, timeout, round_decimal, **kwargs)
+        if kwargs.get("_async", False):
+            return SearchFuture(res)
+        return SearchResult(res)
+
 
     def query(self, expr, output_fields=None, partition_names=None, timeout=None):
         """
@@ -1200,4 +1214,4 @@ def alter_alias(self, alias, timeout=None, **kwargs):
             otherwise return Status(code=1, message='alias does not exist')
         """
         conn = self._get_connection()
-        conn.alter_alias(self._name, alias, timeout=timeout, **kwargs)
\ No newline at end of file
+        conn.alter_alias(self._name, alias, timeout=timeout, **kwargs)
diff --git a/pymilvus/orm/partition.py b/pymilvus/orm/partition.py
index 7ec3f20d9..f7a906ae8 100644
--- a/pymilvus/orm/partition.py
+++ b/pymilvus/orm/partition.py
@@ -407,6 +407,16 @@ def search(self, data, anns_field, param, limit, expr=None, output_fields=None,
         if kwargs.get("_async", False):
             return SearchFuture(res)
         return SearchResult(res)
+
+    def searchByID(self, query_id, anns_field, param, limit, expr=None, output_fields=None, timeout=None, round_decimal=-1,
+                   **kwargs):
+        conn = self._get_connection()
+        res = conn.search_by_id(self._collection.name, query_id, anns_field, param, limit,
+                                expr, [self._name], output_fields, timeout, round_decimal, **kwargs)
+        if kwargs.get("_async", False):
+            return SearchFuture(res)
+        return SearchResult(res)
+
 
     def query(self, expr, output_fields=None, timeout=None):
         """