diff --git a/pinecone/grpc/base.py b/pinecone/grpc/base.py index e1e26792..17580d7e 100644 --- a/pinecone/grpc/base.py +++ b/pinecone/grpc/base.py @@ -1,21 +1,14 @@ -import logging from abc import ABC, abstractmethod -from functools import wraps -from typing import Dict, Optional +from typing import Optional import grpc -from grpc._channel import _InactiveRpcError, Channel +from grpc._channel import Channel -from .retry import RetryConfig from .channel_factory import GrpcChannelFactory from pinecone import Config -from .utils import _generate_request_id from .config import GRPCClientConfig -from pinecone.utils.constants import REQUEST_ID, CLIENT_VERSION -from pinecone.exceptions.exceptions import PineconeException - -_logger = logging.getLogger(__name__) +from .grpc_runner import GrpcRunner class GRPCIndexBase(ABC): @@ -35,18 +28,12 @@ def __init__( ): self.config = config self.grpc_client_config = grpc_config or GRPCClientConfig() - self.retry_config = self.grpc_client_config.retry_config or RetryConfig() - - self.fixed_metadata = { - "api-key": config.api_key, - "service-name": index_name, - "client-version": CLIENT_VERSION, - } - if self.grpc_client_config.additional_metadata: - self.fixed_metadata.update(self.grpc_client_config.additional_metadata) self._endpoint_override = _endpoint_override + self.runner = GrpcRunner( + index_name=index_name, config=config, grpc_config=self.grpc_client_config + ) self.channel_factory = GrpcChannelFactory( config=self.config, grpc_client_config=self.grpc_client_config, use_asyncio=False ) @@ -91,44 +78,6 @@ def close(self): except TypeError: pass - def _wrap_grpc_call( - self, - func, - request, - timeout=None, - metadata=None, - credentials=None, - wait_for_ready=None, - compression=None, - ): - @wraps(func) - def wrapped(): - user_provided_metadata = metadata or {} - _metadata = tuple( - (k, v) - for k, v in { - **self.fixed_metadata, - **self._request_metadata(), - **user_provided_metadata, - }.items() - ) - try: - return func( - request, - timeout=timeout, - metadata=_metadata, - credentials=credentials, - wait_for_ready=wait_for_ready, - compression=compression, - ) - except _InactiveRpcError as e: - raise PineconeException(e._state.debug_error_string) from e - - return wrapped() - - def _request_metadata(self) -> Dict[str, str]: - return {REQUEST_ID: _generate_request_id()} - def __enter__(self): return self diff --git a/pinecone/grpc/grpc_runner.py b/pinecone/grpc/grpc_runner.py new file mode 100644 index 00000000..253a6b33 --- /dev/null +++ b/pinecone/grpc/grpc_runner.py @@ -0,0 +1,97 @@ +from functools import wraps +from typing import Dict, Tuple, Optional + +from grpc._channel import _InactiveRpcError + +from pinecone import Config +from .utils import _generate_request_id +from .config import GRPCClientConfig +from pinecone.utils.constants import REQUEST_ID, CLIENT_VERSION +from pinecone.exceptions.exceptions import PineconeException +from grpc import CallCredentials, Compression +from google.protobuf.message import Message + + +class GrpcRunner: + def __init__(self, index_name: str, config: Config, grpc_config: GRPCClientConfig): + self.config = config + self.grpc_client_config = grpc_config + + self.fixed_metadata = { + "api-key": config.api_key, + "service-name": index_name, + "client-version": CLIENT_VERSION, + } + if self.grpc_client_config.additional_metadata: + self.fixed_metadata.update(self.grpc_client_config.additional_metadata) + + def run( + self, + func, + request: Message, + timeout: Optional[int] = None, + metadata: Optional[Dict[str, str]] = None, + credentials: Optional[CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[Compression] = None, + ): + @wraps(func) + def wrapped(): + user_provided_metadata = metadata or {} + _metadata = self._prepare_metadata(user_provided_metadata) + try: + return func( + request, + timeout=timeout, + metadata=_metadata, + credentials=credentials, + wait_for_ready=wait_for_ready, + compression=compression, + ) + except _InactiveRpcError as e: + raise PineconeException(e._state.debug_error_string) from e + + return wrapped() + + async def run_asyncio( + self, + func, + request: Message, + timeout: Optional[int] = None, + metadata: Optional[Dict[str, str]] = None, + credentials: Optional[CallCredentials] = None, + wait_for_ready: Optional[bool] = None, + compression: Optional[Compression] = None, + ): + @wraps(func) + async def wrapped(): + user_provided_metadata = metadata or {} + _metadata = self._prepare_metadata(user_provided_metadata) + try: + return await func( + request, + timeout=timeout, + metadata=_metadata, + credentials=credentials, + wait_for_ready=wait_for_ready, + compression=compression, + ) + except _InactiveRpcError as e: + raise PineconeException(e._state.debug_error_string) from e + + return await wrapped() + + def _prepare_metadata( + self, user_provided_metadata: Dict[str, str] + ) -> Tuple[Tuple[str, str], ...]: + return tuple( + (k, v) + for k, v in { + **self.fixed_metadata, + **self._request_metadata(), + **user_provided_metadata, + }.items() + ) + + def _request_metadata(self) -> Dict[str, str]: + return {REQUEST_ID: _generate_request_id()} diff --git a/pinecone/grpc/index_grpc.py b/pinecone/grpc/index_grpc.py index 424fb576..6269c23d 100644 --- a/pinecone/grpc/index_grpc.py +++ b/pinecone/grpc/index_grpc.py @@ -133,7 +133,7 @@ def upsert( if async_req: args_dict = self._parse_non_empty_args([("namespace", namespace)]) request = UpsertRequest(vectors=vectors, **args_dict, **kwargs) - future = self._wrap_grpc_call(self.stub.Upsert.future, request, timeout=timeout) + future = self.runner.run(self.stub.Upsert.future, request, timeout=timeout) return PineconeGrpcFuture(future) if batch_size is None: @@ -155,15 +155,11 @@ def upsert( return UpsertResponse(upserted_count=total_upserted) def _upsert_batch( - self, - vectors: List[GRPCVector], - namespace: Optional[str], - timeout: Optional[float], - **kwargs, + self, vectors: List[GRPCVector], namespace: Optional[str], timeout: Optional[int], **kwargs ) -> UpsertResponse: args_dict = self._parse_non_empty_args([("namespace", namespace)]) request = UpsertRequest(vectors=vectors, **args_dict) - return self._wrap_grpc_call(self.stub.Upsert, request, timeout=timeout, **kwargs) + return self.runner.run(self.stub.Upsert, request, timeout=timeout, **kwargs) def upsert_from_dataframe( self, @@ -280,10 +276,10 @@ def delete( request = DeleteRequest(**args_dict, **kwargs) if async_req: - future = self._wrap_grpc_call(self.stub.Delete.future, request, timeout=timeout) + future = self.runner.run(self.stub.Delete.future, request, timeout=timeout) return PineconeGrpcFuture(future) else: - return self._wrap_grpc_call(self.stub.Delete, request, timeout=timeout) + return self.runner.run(self.stub.Delete, request, timeout=timeout) def fetch( self, ids: Optional[List[str]], namespace: Optional[str] = None, **kwargs @@ -308,7 +304,7 @@ def fetch( args_dict = self._parse_non_empty_args([("namespace", namespace)]) request = FetchRequest(ids=ids, **args_dict, **kwargs) - response = self._wrap_grpc_call(self.stub.Fetch, request, timeout=timeout) + response = self.runner.run(self.stub.Fetch, request, timeout=timeout) json_response = json_format.MessageToDict(response) return parse_fetch_response(json_response) @@ -388,7 +384,7 @@ def query( request = QueryRequest(**args_dict) timeout = kwargs.pop("timeout", None) - response = self._wrap_grpc_call(self.stub.Query, request, timeout=timeout) + response = self.runner.run(self.stub.Query, request, timeout=timeout) json_response = json_format.MessageToDict(response) return parse_query_response(json_response, _check_type=False) @@ -451,10 +447,10 @@ def update( request = UpdateRequest(id=id, **args_dict) if async_req: - future = self._wrap_grpc_call(self.stub.Update.future, request, timeout=timeout) + future = self.runner.run(self.stub.Update.future, request, timeout=timeout) return PineconeGrpcFuture(future) else: - return self._wrap_grpc_call(self.stub.Update, request, timeout=timeout) + return self.runner.run(self.stub.Update, request, timeout=timeout) def list_paginated( self, @@ -499,7 +495,7 @@ def list_paginated( ) request = ListRequest(**args_dict, **kwargs) timeout = kwargs.pop("timeout", None) - response = self._wrap_grpc_call(self.stub.List, request, timeout=timeout) + response = self.runner.run(self.stub.List, request, timeout=timeout) if response.pagination and response.pagination.next != "": pagination = Pagination(next=response.pagination.next) @@ -572,7 +568,7 @@ def describe_index_stats( timeout = kwargs.pop("timeout", None) request = DescribeIndexStatsRequest(**args_dict) - response = self._wrap_grpc_call(self.stub.DescribeIndexStats, request, timeout=timeout) + response = self.runner.run(self.stub.DescribeIndexStats, request, timeout=timeout) json_response = json_format.MessageToDict(response) return parse_stats_response(json_response) diff --git a/tests/unit_grpc/test_grpc_index_describe_index_stats.py b/tests/unit_grpc/test_grpc_index_describe_index_stats.py index 0b7388f4..24f8dcf7 100644 --- a/tests/unit_grpc/test_grpc_index_describe_index_stats.py +++ b/tests/unit_grpc/test_grpc_index_describe_index_stats.py @@ -12,16 +12,16 @@ def setup_method(self): ) def test_describeIndexStats_callWithoutFilter_CalledWithoutFilter(self, mocker): - mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True) + mocker.patch.object(self.index.runner, "run", autospec=True) self.index.describe_index_stats() - self.index._wrap_grpc_call.assert_called_once_with( + self.index.runner.run.assert_called_once_with( self.index.stub.DescribeIndexStats, DescribeIndexStatsRequest(), timeout=None ) def test_describeIndexStats_callWithFilter_CalledWithFilter(self, mocker, filter1): - mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True) + mocker.patch.object(self.index.runner, "run", autospec=True) self.index.describe_index_stats(filter=filter1) - self.index._wrap_grpc_call.assert_called_once_with( + self.index.runner.run.assert_called_once_with( self.index.stub.DescribeIndexStats, DescribeIndexStatsRequest(filter=dict_to_proto_struct(filter1)), timeout=None, diff --git a/tests/unit_grpc/test_grpc_index_fetch.py b/tests/unit_grpc/test_grpc_index_fetch.py index 6ccb4199..620acfbb 100644 --- a/tests/unit_grpc/test_grpc_index_fetch.py +++ b/tests/unit_grpc/test_grpc_index_fetch.py @@ -11,15 +11,15 @@ def setup_method(self): ) def test_fetch_byIds_fetchByIds(self, mocker): - mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True) + mocker.patch.object(self.index.runner, "run", autospec=True) self.index.fetch(["vec1", "vec2"]) - self.index._wrap_grpc_call.assert_called_once_with( + self.index.runner.run.assert_called_once_with( self.index.stub.Fetch, FetchRequest(ids=["vec1", "vec2"]), timeout=None ) def test_fetch_byIdsAndNS_fetchByIdsAndNS(self, mocker): - mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True) + mocker.patch.object(self.index.runner, "run", autospec=True) self.index.fetch(["vec1", "vec2"], namespace="ns", timeout=30) - self.index._wrap_grpc_call.assert_called_once_with( + self.index.runner.run.assert_called_once_with( self.index.stub.Fetch, FetchRequest(ids=["vec1", "vec2"], namespace="ns"), timeout=30 ) diff --git a/tests/unit_grpc/test_grpc_index_initialization.py b/tests/unit_grpc/test_grpc_index_initialization.py index c12689ee..55fa5593 100644 --- a/tests/unit_grpc/test_grpc_index_initialization.py +++ b/tests/unit_grpc/test_grpc_index_initialization.py @@ -15,25 +15,6 @@ def test_init_with_default_config(self): assert index.grpc_client_config.grpc_channel_options is None assert index.grpc_client_config.additional_metadata is None - # Default metadata, grpc equivalent to http request headers - assert len(index.fixed_metadata) == 3 - assert index.fixed_metadata["api-key"] == "YOUR_API_KEY" - assert index.fixed_metadata["service-name"] == "my-index" - assert index.fixed_metadata["client-version"] is not None - - def test_init_with_additional_metadata(self): - pc = PineconeGRPC(api_key="YOUR_API_KEY") - config = GRPCClientConfig( - additional_metadata={"debug-header": "value123", "debug-header2": "value456"} - ) - index = pc.Index(name="my-index", host="host", grpc_config=config) - assert len(index.fixed_metadata) == 5 - assert index.fixed_metadata["api-key"] == "YOUR_API_KEY" - assert index.fixed_metadata["service-name"] == "my-index" - assert index.fixed_metadata["client-version"] is not None - assert index.fixed_metadata["debug-header"] == "value123" - assert index.fixed_metadata["debug-header2"] == "value456" - def test_init_with_grpc_config_from_dict(self): pc = PineconeGRPC(api_key="YOUR_API_KEY") config = GRPCClientConfig._from_dict({"timeout": 10}) diff --git a/tests/unit_grpc/test_grpc_index_query.py b/tests/unit_grpc/test_grpc_index_query.py index a871656c..86e0f6a2 100644 --- a/tests/unit_grpc/test_grpc_index_query.py +++ b/tests/unit_grpc/test_grpc_index_query.py @@ -14,16 +14,16 @@ def setup_method(self): ) def test_query_byVectorNoFilter_queryVectorNoFilter(self, mocker, vals1): - mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True) + mocker.patch.object(self.index.runner, "run", autospec=True) self.index.query(top_k=10, vector=vals1) - self.index._wrap_grpc_call.assert_called_once_with( + self.index.runner.run.assert_called_once_with( self.index.stub.Query, QueryRequest(top_k=10, vector=vals1), timeout=None ) def test_query_byVectorWithFilter_queryVectorWithFilter(self, mocker, vals1, filter1): - mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True) + mocker.patch.object(self.index.runner, "run", autospec=True) self.index.query(top_k=10, vector=vals1, filter=filter1, namespace="ns", timeout=10) - self.index._wrap_grpc_call.assert_called_once_with( + self.index.runner.run.assert_called_once_with( self.index.stub.Query, QueryRequest( top_k=10, vector=vals1, filter=dict_to_proto_struct(filter1), namespace="ns" @@ -32,9 +32,9 @@ def test_query_byVectorWithFilter_queryVectorWithFilter(self, mocker, vals1, fil ) def test_query_byVecId_queryByVecId(self, mocker): - mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True) + mocker.patch.object(self.index.runner, "run", autospec=True) self.index.query(top_k=10, id="vec1", include_metadata=True, include_values=False) - self.index._wrap_grpc_call.assert_called_once_with( + self.index.runner.run.assert_called_once_with( self.index.stub.Query, QueryRequest(top_k=10, id="vec1", include_metadata=True, include_values=False), timeout=None, diff --git a/tests/unit_grpc/test_grpc_index_update.py b/tests/unit_grpc/test_grpc_index_update.py index 12774195..978e06aa 100644 --- a/tests/unit_grpc/test_grpc_index_update.py +++ b/tests/unit_grpc/test_grpc_index_update.py @@ -12,18 +12,18 @@ def setup_method(self): ) def test_update_byIdAnValues_updateByIdAndValues(self, mocker, vals1): - mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True) + mocker.patch.object(self.index.runner, "run", autospec=True) self.index.update(id="vec1", values=vals1, namespace="ns", timeout=30) - self.index._wrap_grpc_call.assert_called_once_with( + self.index.runner.run.assert_called_once_with( self.index.stub.Update, UpdateRequest(id="vec1", values=vals1, namespace="ns"), timeout=30, ) def test_update_byIdAnValuesAsync_updateByIdAndValuesAsync(self, mocker, vals1): - mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True) + mocker.patch.object(self.index.runner, "run", autospec=True) self.index.update(id="vec1", values=vals1, namespace="ns", timeout=30, async_req=True) - self.index._wrap_grpc_call.assert_called_once_with( + self.index.runner.run.assert_called_once_with( self.index.stub.Update.future, UpdateRequest(id="vec1", values=vals1, namespace="ns"), timeout=30, @@ -32,9 +32,9 @@ def test_update_byIdAnValuesAsync_updateByIdAndValuesAsync(self, mocker, vals1): def test_update_byIdAnValuesAndMetadata_updateByIdAndValuesAndMetadata( self, mocker, vals1, md1 ): - mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True) + mocker.patch.object(self.index.runner, "run", autospec=True) self.index.update("vec1", values=vals1, set_metadata=md1) - self.index._wrap_grpc_call.assert_called_once_with( + self.index.runner.run.assert_called_once_with( self.index.stub.Update, UpdateRequest(id="vec1", values=vals1, set_metadata=dict_to_proto_struct(md1)), timeout=None, diff --git a/tests/unit_grpc/test_grpc_index_upsert.py b/tests/unit_grpc/test_grpc_index_upsert.py index fb5de75a..cd65b7de 100644 --- a/tests/unit_grpc/test_grpc_index_upsert.py +++ b/tests/unit_grpc/test_grpc_index_upsert.py @@ -71,7 +71,7 @@ def setup_method(self): ) def _assert_called_once(self, vectors, async_call=False): - self.index._wrap_grpc_call.assert_called_once_with( + self.index.runner.run.assert_called_once_with( self.index.stub.Upsert.future if async_call else self.index.stub.Upsert, UpsertRequest(vectors=vectors, namespace="ns"), timeout=None, @@ -80,19 +80,19 @@ def _assert_called_once(self, vectors, async_call=False): def test_upsert_tuplesOfIdVec_UpserWithoutMD( self, mocker, vals1, vals2, expected_vec1, expected_vec2 ): - mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True) + mocker.patch.object(self.index.runner, "run", autospec=True) self.index.upsert([("vec1", vals1), ("vec2", vals2)], namespace="ns") self._assert_called_once([expected_vec1, expected_vec2]) def test_upsert_tuplesOfIdVecMD_UpsertVectorsWithMD( self, mocker, vals1, md1, vals2, md2, expected_vec_md1, expected_vec_md2 ): - mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True) + mocker.patch.object(self.index.runner, "run", autospec=True) self.index.upsert([("vec1", vals1, md1), ("vec2", vals2, md2)], namespace="ns") self._assert_called_once([expected_vec_md1, expected_vec_md2]) def test_upsert_vectors_upsertInputVectors(self, mocker, expected_vec_md1, expected_vec_md2): - mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True) + mocker.patch.object(self.index.runner, "run", autospec=True) self.index.upsert([expected_vec_md1, expected_vec_md2], namespace="ns") self._assert_called_once([expected_vec_md1, expected_vec_md2]) @@ -110,7 +110,7 @@ def test_upsert_vectors_upsertInputVectorsSparse( expected_vec_md_sparse1, expected_vec_md_sparse2, ): - mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True) + mocker.patch.object(self.index.runner, "run", autospec=True) self.index.upsert( [ Vector( @@ -131,7 +131,7 @@ def test_upsert_vectors_upsertInputVectorsSparse( self._assert_called_once([expected_vec_md_sparse1, expected_vec_md_sparse2]) def test_upsert_dict(self, mocker, vals1, vals2, expected_vec1, expected_vec2): - mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True) + mocker.patch.object(self.index.runner, "run", autospec=True) dict1 = {"id": "vec1", "values": vals1} dict2 = {"id": "vec2", "values": vals2} self.index.upsert([dict1, dict2], namespace="ns") @@ -140,7 +140,7 @@ def test_upsert_dict(self, mocker, vals1, vals2, expected_vec1, expected_vec2): def test_upsert_dict_md( self, mocker, vals1, md1, vals2, md2, expected_vec_md1, expected_vec_md2 ): - mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True) + mocker.patch.object(self.index.runner, "run", autospec=True) dict1 = {"id": "vec1", "values": vals1, "metadata": md1} dict2 = {"id": "vec2", "values": vals2, "metadata": md2} self.index.upsert([dict1, dict2], namespace="ns") @@ -156,7 +156,7 @@ def test_upsert_dict_sparse( sparse_indices_2, sparse_values_2, ): - mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True) + mocker.patch.object(self.index.runner, "run", autospec=True) dict1 = { "id": "vec1", "values": vals1, @@ -197,7 +197,7 @@ def test_upsert_dict_sparse_md( sparse_indices_2, sparse_values_2, ): - mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True) + mocker.patch.object(self.index.runner, "run", autospec=True) dict1 = { "id": "vec1", "values": vals1, @@ -229,7 +229,7 @@ def test_upsert_dict_sparse_md( ) def test_upsert_dict_negative(self, mocker, vals1, vals2, md2): - mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True) + mocker.patch.object(self.index.runner, "run", autospec=True) # Missing required keys dict1 = {"values": vals1} @@ -269,7 +269,7 @@ def test_upsert_dict_negative(self, mocker, vals1, vals2, md2): def test_upsert_dict_with_invalid_values( self, mocker, key, new_val, vals1, md1, sparse_indices_1, sparse_values_1 ): - mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True) + mocker.patch.object(self.index.runner, "run", autospec=True) full_dict1 = { "id": "vec1", @@ -288,7 +288,7 @@ def test_upsert_dict_with_invalid_values( def test_upsert_dict_with_invalid_ids( self, mocker, key, new_val, vals1, md1, sparse_indices_1, sparse_values_1 ): - mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True) + mocker.patch.object(self.index.runner, "run", autospec=True) full_dict1 = { "id": "vec1", @@ -310,7 +310,7 @@ def test_upsert_dict_with_invalid_ids( def test_upsert_dict_with_invalid_sparse_values( self, mocker, key, new_val, vals1, md1, sparse_indices_1, sparse_values_1 ): - mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True) + mocker.patch.object(self.index.runner, "run", autospec=True) full_dict1 = { "id": "vec1", @@ -341,8 +341,8 @@ def test_upsert_dataframe( expected_vec_md_sparse2, ): mocker.patch.object( - self.index, - "_wrap_grpc_call", + self.index.runner, + "run", autospec=True, side_effect=lambda stub, upsert_request, timeout: MockUpsertDelegate( UpsertResponse(upserted_count=len(upsert_request.vectors)) @@ -384,8 +384,8 @@ def test_upsert_dataframe_sync( expected_vec_md_sparse2, ): mocker.patch.object( - self.index, - "_wrap_grpc_call", + self.index.runner, + "run", autospec=True, side_effect=lambda stub, upsert_request, timeout: UpsertResponse( upserted_count=len(upsert_request.vectors) @@ -424,7 +424,7 @@ def test_upsert_dataframe_negative( sparse_indices_2, sparse_values_2, ): - mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True) + mocker.patch.object(self.index.runner, "run", autospec=True) full_dict1 = { "id": "vec1", "values": vals1, @@ -457,7 +457,7 @@ def test_upsert_dataframe_negative( self.index.upsert_from_dataframe(df) def test_upsert_async_upsertInputVectorsAsync(self, mocker, expected_vec_md1, expected_vec_md2): - mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True) + mocker.patch.object(self.index.runner, "run", autospec=True) self.index.upsert([expected_vec_md1, expected_vec_md2], namespace="ns", async_req=True) self._assert_called_once([expected_vec_md1, expected_vec_md2], async_call=True) @@ -465,8 +465,8 @@ def test_upsert_vectorListIsMultiplyOfBatchSize_vectorsUpsertedInBatches( self, mocker, vals1, md1, expected_vec_md1, expected_vec_md2 ): mocker.patch.object( - self.index, - "_wrap_grpc_call", + self.index.runner, + "run", autospec=True, side_effect=lambda stub, upsert_request, timeout: UpsertResponse( upserted_count=len(upsert_request.vectors) @@ -476,7 +476,7 @@ def test_upsert_vectorListIsMultiplyOfBatchSize_vectorsUpsertedInBatches( result = self.index.upsert( [expected_vec_md1, expected_vec_md2], namespace="ns", batch_size=1, show_progress=False ) - self.index._wrap_grpc_call.assert_any_call( + self.index.runner.run.assert_any_call( self.index.stub.Upsert, UpsertRequest( vectors=[Vector(id="vec1", values=vals1, metadata=dict_to_proto_struct(md1))], @@ -485,7 +485,7 @@ def test_upsert_vectorListIsMultiplyOfBatchSize_vectorsUpsertedInBatches( timeout=None, ) - self.index._wrap_grpc_call.assert_any_call( + self.index.runner.run.assert_any_call( self.index.stub.Upsert, UpsertRequest(vectors=[expected_vec_md2], namespace="ns"), timeout=None, @@ -497,8 +497,8 @@ def test_upsert_vectorListNotMultiplyOfBatchSize_vectorsUpsertedInBatches( self, mocker, vals1, vals2, md1, md2, expected_vec_md1, expected_vec_md2 ): mocker.patch.object( - self.index, - "_wrap_grpc_call", + self.index.runner, + "run", autospec=True, side_effect=lambda stub, upsert_request, timeout: UpsertResponse( upserted_count=len(upsert_request.vectors) @@ -514,13 +514,13 @@ def test_upsert_vectorListNotMultiplyOfBatchSize_vectorsUpsertedInBatches( namespace="ns", batch_size=2, ) - self.index._wrap_grpc_call.assert_any_call( + self.index.runner.run.assert_any_call( self.index.stub.Upsert, UpsertRequest(vectors=[expected_vec_md1, expected_vec_md2], namespace="ns"), timeout=None, ) - self.index._wrap_grpc_call.assert_any_call( + self.index.runner.run.assert_any_call( self.index.stub.Upsert, UpsertRequest( vectors=[Vector(id="vec3", values=vals1, metadata=dict_to_proto_struct(md1))], @@ -535,8 +535,8 @@ def test_upsert_vectorListSmallerThanBatchSize_vectorsUpsertedInBatches( self, mocker, expected_vec_md1, expected_vec_md2 ): mocker.patch.object( - self.index, - "_wrap_grpc_call", + self.index.runner, + "run", autospec=True, side_effect=lambda stub, upsert_request, timeout: UpsertResponse( upserted_count=len(upsert_request.vectors) @@ -554,8 +554,8 @@ def test_upsert_tuplesList_vectorsUpsertedInBatches( self, mocker, vals1, md1, vals2, md2, expected_vec_md1, expected_vec_md2 ): mocker.patch.object( - self.index, - "_wrap_grpc_call", + self.index.runner, + "run", autospec=True, side_effect=lambda stub, upsert_request, timeout: UpsertResponse( upserted_count=len(upsert_request.vectors) @@ -567,13 +567,13 @@ def test_upsert_tuplesList_vectorsUpsertedInBatches( namespace="ns", batch_size=2, ) - self.index._wrap_grpc_call.assert_any_call( + self.index.runner.run.assert_any_call( self.index.stub.Upsert, UpsertRequest(vectors=[expected_vec_md1, expected_vec_md2], namespace="ns"), timeout=None, ) - self.index._wrap_grpc_call.assert_any_call( + self.index.runner.run.assert_any_call( self.index.stub.Upsert, UpsertRequest( vectors=[Vector(id="vec3", values=vals1, metadata=dict_to_proto_struct(md1))], diff --git a/tests/unit_grpc/test_runner.py b/tests/unit_grpc/test_runner.py new file mode 100644 index 00000000..7a7670c8 --- /dev/null +++ b/tests/unit_grpc/test_runner.py @@ -0,0 +1,120 @@ +import uuid + +from grpc import Compression + +from pinecone.config import Config +from pinecone.grpc.config import GRPCClientConfig +from pinecone.grpc.grpc_runner import GrpcRunner +from pinecone.utils.constants import CLIENT_VERSION + + +class TestGrpcRunner: + def test_run_with_default_metadata(self, mocker): + config = Config(api_key="YOUR_API_KEY") + runner = GrpcRunner(index_name="my-index", config=config, grpc_config=GRPCClientConfig()) + + mock_func = mocker.Mock() + runner.run(mock_func, request="request") + + passed_metadata = mock_func.call_args.kwargs["metadata"] + # Fixed metadata fields + assert ("api-key", "YOUR_API_KEY") in passed_metadata + assert ("service-name", "my-index") in passed_metadata + assert ("client-version", CLIENT_VERSION) in passed_metadata + + # Request id assigned for each request + assert any( + item[0] == "request_id" for item in passed_metadata + ), "request_id not found in metadata" + for items in passed_metadata: + if items[0] == "request_id": + assert isinstance(items[1], str) + assert uuid.UUID(items[1], version=4), "request_id is not a valid UUID" + + def test_each_run_gets_unique_request_id(self, mocker): + config = Config(api_key="YOUR_API_KEY") + runner = GrpcRunner(index_name="my-index", config=config, grpc_config=GRPCClientConfig()) + + mock_func = mocker.Mock() + runner.run(mock_func, request="request") + + for items in mock_func.call_args.kwargs["metadata"]: + if items[0] == "request_id": + first_request_id = items[1] + + mock_func.reset_mock() + runner.run(mock_func, request="request") + for items in mock_func.call_args.kwargs["metadata"]: + if items[0] == "request_id": + second_request_id = items[1] + assert ( + second_request_id != first_request_id + ), "request_id is not unique for each request" + + def test_run_with_additional_metadata_from_grpc_config(self, mocker): + config = Config(api_key="YOUR_API_KEY") + grpc_config = GRPCClientConfig( + additional_metadata={"debug-header": "value123", "debug-header2": "value456"} + ) + runner = GrpcRunner(index_name="my-index", config=config, grpc_config=grpc_config) + + mock_func = mocker.Mock() + runner.run(mock_func, request="request") + + passed_metadata = mock_func.call_args.kwargs["metadata"] + assert ("api-key", "YOUR_API_KEY") in passed_metadata + assert ("service-name", "my-index") in passed_metadata + assert ("client-version", CLIENT_VERSION) in passed_metadata + assert ("debug-header", "value123") in passed_metadata + assert ("debug-header2", "value456") in passed_metadata + + def test_with_additional_metadata_from_run(self, mocker): + config = Config(api_key="YOUR_API_KEY") + grpc_config = GRPCClientConfig( + additional_metadata={"debug-header": "value123", "debug-header2": "value456"} + ) + runner = GrpcRunner(index_name="my-index", config=config, grpc_config=grpc_config) + + mock_func = mocker.Mock() + runner.run( + mock_func, + request="request", + metadata={"user-extra": "extra-value", "user-extra2": "extra-value2"}, + ) + + passed_metadata = mock_func.call_args.kwargs["metadata"] + + # Fixed metadata fields + assert ("api-key", "YOUR_API_KEY") in passed_metadata + assert ("service-name", "my-index") in passed_metadata + assert ("client-version", CLIENT_VERSION) in passed_metadata + # Request id + assert any( + item[0] == "request_id" for item in passed_metadata + ), "request_id not found in metadata" + # Extras from configuration + assert ("debug-header", "value123") in passed_metadata + assert ("debug-header2", "value456") in passed_metadata + # Extras from call to run() + assert ("user-extra", "extra-value") in passed_metadata + assert ("user-extra2", "extra-value2") in passed_metadata + + def test_run_with_other_args(self, mocker): + config = Config(api_key="YOUR_API_KEY") + grpc_config = GRPCClientConfig( + additional_metadata={"debug-header": "value123", "debug-header2": "value456"} + ) + runner = GrpcRunner(index_name="my-index", config=config, grpc_config=grpc_config) + + mock_func = mocker.Mock() + runner.run( + mock_func, + request="request", + timeout=10, + wait_for_ready=True, + compression=Compression.Gzip, + ) + + assert mock_func.call_args.kwargs["timeout"] == 10 + assert mock_func.call_args.kwargs["wait_for_ready"] == True + assert mock_func.call_args.kwargs["compression"] == Compression.Gzip