Skip to content

Commit

Permalink
[Refactor] Extract GrpcRunner from GRPCIndexBase class (#395)
Browse files Browse the repository at this point in the history
## Problem

This is another extractive refactoring in preparation for grpc with
asyncio.

## Solution

The generated stub class, `VectorServiceStub`, is what knows how to call
the Pinecone grpc service, but our wrapper code needs to do some work to
make sure we have a consistent approach to "metadata" (grpc-speak for
request headers) and handling other request params like `timeout`.
Previously this work was accomplished in a private method of the
`GRPCIndexBase` base class called `_wrap_grpc_call()`.

Since we will need to perform almost identical marshaling of metadata
for requests with asyncio, I pulled this logic out into a separate class
`GrpcRunner` and renamed `_wrap_grpc_call` to `run`. There is also a
parallel method, `run_asyncio`; it is currently unused and untested, but
it illustrates why this refactor is useful.

## Type of Change

- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [ ] This change requires a documentation update
- [ ] Infrastructure change (CI configs, etc)
- [ ] Non-code change (docs, etc)
- [x] None of the above: Mechanical refactor, should have no net impact
to functionality.

## Test Plan

Tests should still be green
  • Loading branch information
jhamon authored Oct 15, 2024
1 parent 4c18899 commit 2831a5e
Show file tree
Hide file tree
Showing 10 changed files with 287 additions and 144 deletions.
63 changes: 6 additions & 57 deletions pinecone/grpc/base.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,14 @@
import logging
from abc import ABC, abstractmethod
from functools import wraps
from typing import Dict, Optional
from typing import Optional

import grpc
from grpc._channel import _InactiveRpcError, Channel
from grpc._channel import Channel

from .retry import RetryConfig
from .channel_factory import GrpcChannelFactory

from pinecone import Config
from .utils import _generate_request_id
from .config import GRPCClientConfig
from pinecone.utils.constants import REQUEST_ID, CLIENT_VERSION
from pinecone.exceptions.exceptions import PineconeException

_logger = logging.getLogger(__name__)
from .grpc_runner import GrpcRunner


class GRPCIndexBase(ABC):
Expand All @@ -35,18 +28,12 @@ def __init__(
):
self.config = config
self.grpc_client_config = grpc_config or GRPCClientConfig()
self.retry_config = self.grpc_client_config.retry_config or RetryConfig()

self.fixed_metadata = {
"api-key": config.api_key,
"service-name": index_name,
"client-version": CLIENT_VERSION,
}
if self.grpc_client_config.additional_metadata:
self.fixed_metadata.update(self.grpc_client_config.additional_metadata)

self._endpoint_override = _endpoint_override

self.runner = GrpcRunner(
index_name=index_name, config=config, grpc_config=self.grpc_client_config
)
self.channel_factory = GrpcChannelFactory(
config=self.config, grpc_client_config=self.grpc_client_config, use_asyncio=False
)
Expand Down Expand Up @@ -91,44 +78,6 @@ def close(self):
except TypeError:
pass

def _wrap_grpc_call(
    self,
    func,
    request,
    timeout=None,
    metadata=None,
    credentials=None,
    wait_for_ready=None,
    compression=None,
):
    """Call ``func`` with ``request`` and the merged request metadata.

    The metadata sent with the call is built from ``self.fixed_metadata``,
    then ``self._request_metadata()``, then the caller's ``metadata``; on a
    key collision the later source wins. The merged mapping is handed to
    ``func`` as a tuple of ``(key, value)`` pairs.

    Returns:
        Whatever ``func`` returns.

    Raises:
        PineconeException: if ``func`` raises ``_InactiveRpcError``; the
            original error is chained as the cause.
    """
    merged = {
        **self.fixed_metadata,
        **self._request_metadata(),
        **(metadata or {}),
    }
    call_options = {
        "timeout": timeout,
        "metadata": tuple(merged.items()),
        "credentials": credentials,
        "wait_for_ready": wait_for_ready,
        "compression": compression,
    }
    try:
        return func(request, **call_options)
    except _InactiveRpcError as err:
        raise PineconeException(err._state.debug_error_string) from err

def _request_metadata(self) -> Dict[str, str]:
    """Return the per-request metadata: a single ``REQUEST_ID`` entry.

    The value is whatever ``_generate_request_id()`` returns for this call.
    """
    request_id = _generate_request_id()
    return {REQUEST_ID: request_id}

def __enter__(self):
    """Enter a ``with`` block; the object bound by ``as`` is this instance."""
    return self

Expand Down
97 changes: 97 additions & 0 deletions pinecone/grpc/grpc_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from functools import wraps
from typing import Dict, Tuple, Optional

from grpc._channel import _InactiveRpcError

from pinecone import Config
from .utils import _generate_request_id
from .config import GRPCClientConfig
from pinecone.utils.constants import REQUEST_ID, CLIENT_VERSION
from pinecone.exceptions.exceptions import PineconeException
from grpc import CallCredentials, Compression
from google.protobuf.message import Message


class GrpcRunner:
    """Invoke gRPC stub callables with a consistent set of request metadata.

    Each call is sent with the fixed metadata assembled in ``__init__``
    (API key, index name, client version, plus any configured additional
    metadata), the per-request metadata from ``_request_metadata()``, and
    any metadata the caller supplies for that specific request.
    """

    def __init__(self, index_name: str, config: Config, grpc_config: GRPCClientConfig):
        """Store the configuration and build the metadata sent on every call.

        Args:
            index_name: Sent as the ``service-name`` metadata entry.
            config: Client configuration; its ``api_key`` is sent as ``api-key``.
            grpc_config: gRPC client configuration. When its
                ``additional_metadata`` is truthy, those entries are merged
                over the three default entries.
        """
        self.config = config
        self.grpc_client_config = grpc_config

        self.fixed_metadata = {
            "api-key": config.api_key,
            "service-name": index_name,
            "client-version": CLIENT_VERSION,
        }
        if self.grpc_client_config.additional_metadata:
            self.fixed_metadata.update(self.grpc_client_config.additional_metadata)

    def run(
        self,
        func,
        request: Message,
        timeout: Optional[float] = None,
        metadata: Optional[Dict[str, str]] = None,
        credentials: Optional[CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[Compression] = None,
    ):
        """Synchronously call ``func`` with ``request`` and merged metadata.

        Args:
            func: Callable invoked as ``func(request, timeout=..., metadata=...,
                credentials=..., wait_for_ready=..., compression=...)``.
            request: The request message passed through to ``func``.
            timeout: Passed through to ``func`` unchanged.
            metadata: Extra metadata for this call only; merged last, so it
                overrides fixed and per-request entries with the same key.
            credentials: Passed through to ``func`` unchanged.
            wait_for_ready: Passed through to ``func`` unchanged.
            compression: Passed through to ``func`` unchanged.

        Returns:
            Whatever ``func`` returns.

        Raises:
            PineconeException: if ``func`` raises ``_InactiveRpcError``; the
                original error is chained as the cause.
        """

        @wraps(func)
        def wrapped():
            user_provided_metadata = metadata or {}
            _metadata = self._prepare_metadata(user_provided_metadata)
            try:
                return func(
                    request,
                    timeout=timeout,
                    metadata=_metadata,
                    credentials=credentials,
                    wait_for_ready=wait_for_ready,
                    compression=compression,
                )
            except _InactiveRpcError as e:
                raise PineconeException(e._state.debug_error_string) from e

        return wrapped()

    async def run_asyncio(
        self,
        func,
        request: Message,
        timeout: Optional[float] = None,
        metadata: Optional[Dict[str, str]] = None,
        credentials: Optional[CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[Compression] = None,
    ):
        """Await ``func`` with ``request`` and merged metadata.

        Takes the same arguments as ``run``; the only difference is that the
        result of ``func(...)`` is awaited.

        Returns:
            The awaited result of ``func``.

        Raises:
            PineconeException: if the call raises ``_InactiveRpcError`` or
                ``grpc.aio.AioRpcError``; the original error is chained as
                the cause.
        """
        # Imported here so the asyncio-only dependency is resolved only when
        # this code path is actually used.
        from grpc.aio import AioRpcError

        @wraps(func)
        async def wrapped():
            user_provided_metadata = metadata or {}
            _metadata = self._prepare_metadata(user_provided_metadata)
            try:
                return await func(
                    request,
                    timeout=timeout,
                    metadata=_metadata,
                    credentials=credentials,
                    wait_for_ready=wait_for_ready,
                    compression=compression,
                )
            except _InactiveRpcError as e:
                raise PineconeException(e._state.debug_error_string) from e
            except AioRpcError as e:
                # The asyncio API reports RPC failures with AioRpcError rather
                # than _InactiveRpcError, and exposes the debug string as a
                # method instead of via a private ``_state`` attribute.
                raise PineconeException(e.debug_error_string()) from e

        return await wrapped()

    def _prepare_metadata(
        self, user_provided_metadata: Dict[str, str]
    ) -> Tuple[Tuple[str, str], ...]:
        """Merge all metadata sources into a tuple of ``(key, value)`` pairs.

        Sources are merged in the order fixed metadata, per-request metadata,
        user-provided metadata; on a key collision the later source wins.
        """
        return tuple(
            (k, v)
            for k, v in {
                **self.fixed_metadata,
                **self._request_metadata(),
                **user_provided_metadata,
            }.items()
        )

    def _request_metadata(self) -> Dict[str, str]:
        """Return the per-request metadata: a single ``REQUEST_ID`` entry."""
        return {REQUEST_ID: _generate_request_id()}
26 changes: 11 additions & 15 deletions pinecone/grpc/index_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def upsert(
if async_req:
args_dict = self._parse_non_empty_args([("namespace", namespace)])
request = UpsertRequest(vectors=vectors, **args_dict, **kwargs)
future = self._wrap_grpc_call(self.stub.Upsert.future, request, timeout=timeout)
future = self.runner.run(self.stub.Upsert.future, request, timeout=timeout)
return PineconeGrpcFuture(future)

if batch_size is None:
Expand All @@ -155,15 +155,11 @@ def upsert(
return UpsertResponse(upserted_count=total_upserted)

def _upsert_batch(
self,
vectors: List[GRPCVector],
namespace: Optional[str],
timeout: Optional[float],
**kwargs,
self, vectors: List[GRPCVector], namespace: Optional[str], timeout: Optional[int], **kwargs
) -> UpsertResponse:
args_dict = self._parse_non_empty_args([("namespace", namespace)])
request = UpsertRequest(vectors=vectors, **args_dict)
return self._wrap_grpc_call(self.stub.Upsert, request, timeout=timeout, **kwargs)
return self.runner.run(self.stub.Upsert, request, timeout=timeout, **kwargs)

def upsert_from_dataframe(
self,
Expand Down Expand Up @@ -280,10 +276,10 @@ def delete(

request = DeleteRequest(**args_dict, **kwargs)
if async_req:
future = self._wrap_grpc_call(self.stub.Delete.future, request, timeout=timeout)
future = self.runner.run(self.stub.Delete.future, request, timeout=timeout)
return PineconeGrpcFuture(future)
else:
return self._wrap_grpc_call(self.stub.Delete, request, timeout=timeout)
return self.runner.run(self.stub.Delete, request, timeout=timeout)

def fetch(
self, ids: Optional[List[str]], namespace: Optional[str] = None, **kwargs
Expand All @@ -308,7 +304,7 @@ def fetch(
args_dict = self._parse_non_empty_args([("namespace", namespace)])

request = FetchRequest(ids=ids, **args_dict, **kwargs)
response = self._wrap_grpc_call(self.stub.Fetch, request, timeout=timeout)
response = self.runner.run(self.stub.Fetch, request, timeout=timeout)
json_response = json_format.MessageToDict(response)
return parse_fetch_response(json_response)

Expand Down Expand Up @@ -388,7 +384,7 @@ def query(
request = QueryRequest(**args_dict)

timeout = kwargs.pop("timeout", None)
response = self._wrap_grpc_call(self.stub.Query, request, timeout=timeout)
response = self.runner.run(self.stub.Query, request, timeout=timeout)
json_response = json_format.MessageToDict(response)
return parse_query_response(json_response, _check_type=False)

Expand Down Expand Up @@ -451,10 +447,10 @@ def update(

request = UpdateRequest(id=id, **args_dict)
if async_req:
future = self._wrap_grpc_call(self.stub.Update.future, request, timeout=timeout)
future = self.runner.run(self.stub.Update.future, request, timeout=timeout)
return PineconeGrpcFuture(future)
else:
return self._wrap_grpc_call(self.stub.Update, request, timeout=timeout)
return self.runner.run(self.stub.Update, request, timeout=timeout)

def list_paginated(
self,
Expand Down Expand Up @@ -499,7 +495,7 @@ def list_paginated(
)
request = ListRequest(**args_dict, **kwargs)
timeout = kwargs.pop("timeout", None)
response = self._wrap_grpc_call(self.stub.List, request, timeout=timeout)
response = self.runner.run(self.stub.List, request, timeout=timeout)

if response.pagination and response.pagination.next != "":
pagination = Pagination(next=response.pagination.next)
Expand Down Expand Up @@ -572,7 +568,7 @@ def describe_index_stats(
timeout = kwargs.pop("timeout", None)

request = DescribeIndexStatsRequest(**args_dict)
response = self._wrap_grpc_call(self.stub.DescribeIndexStats, request, timeout=timeout)
response = self.runner.run(self.stub.DescribeIndexStats, request, timeout=timeout)
json_response = json_format.MessageToDict(response)
return parse_stats_response(json_response)

Expand Down
8 changes: 4 additions & 4 deletions tests/unit_grpc/test_grpc_index_describe_index_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@ def setup_method(self):
)

def test_describeIndexStats_callWithoutFilter_CalledWithoutFilter(self, mocker):
mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True)
mocker.patch.object(self.index.runner, "run", autospec=True)
self.index.describe_index_stats()
self.index._wrap_grpc_call.assert_called_once_with(
self.index.runner.run.assert_called_once_with(
self.index.stub.DescribeIndexStats, DescribeIndexStatsRequest(), timeout=None
)

def test_describeIndexStats_callWithFilter_CalledWithFilter(self, mocker, filter1):
mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True)
mocker.patch.object(self.index.runner, "run", autospec=True)
self.index.describe_index_stats(filter=filter1)
self.index._wrap_grpc_call.assert_called_once_with(
self.index.runner.run.assert_called_once_with(
self.index.stub.DescribeIndexStats,
DescribeIndexStatsRequest(filter=dict_to_proto_struct(filter1)),
timeout=None,
Expand Down
8 changes: 4 additions & 4 deletions tests/unit_grpc/test_grpc_index_fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ def setup_method(self):
)

def test_fetch_byIds_fetchByIds(self, mocker):
mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True)
mocker.patch.object(self.index.runner, "run", autospec=True)
self.index.fetch(["vec1", "vec2"])
self.index._wrap_grpc_call.assert_called_once_with(
self.index.runner.run.assert_called_once_with(
self.index.stub.Fetch, FetchRequest(ids=["vec1", "vec2"]), timeout=None
)

def test_fetch_byIdsAndNS_fetchByIdsAndNS(self, mocker):
mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True)
mocker.patch.object(self.index.runner, "run", autospec=True)
self.index.fetch(["vec1", "vec2"], namespace="ns", timeout=30)
self.index._wrap_grpc_call.assert_called_once_with(
self.index.runner.run.assert_called_once_with(
self.index.stub.Fetch, FetchRequest(ids=["vec1", "vec2"], namespace="ns"), timeout=30
)
19 changes: 0 additions & 19 deletions tests/unit_grpc/test_grpc_index_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,6 @@ def test_init_with_default_config(self):
assert index.grpc_client_config.grpc_channel_options is None
assert index.grpc_client_config.additional_metadata is None

# Default metadata, grpc equivalent to http request headers
assert len(index.fixed_metadata) == 3
assert index.fixed_metadata["api-key"] == "YOUR_API_KEY"
assert index.fixed_metadata["service-name"] == "my-index"
assert index.fixed_metadata["client-version"] is not None

def test_init_with_additional_metadata(self):
    """Additional metadata from the gRPC config is merged into the fixed metadata."""
    pc = PineconeGRPC(api_key="YOUR_API_KEY")
    extra_headers = {"debug-header": "value123", "debug-header2": "value456"}
    config = GRPCClientConfig(additional_metadata=extra_headers)
    index = pc.Index(name="my-index", host="host", grpc_config=config)

    fixed = index.fixed_metadata
    # Three default entries plus the two additional ones.
    assert len(fixed) == 5
    assert fixed["api-key"] == "YOUR_API_KEY"
    assert fixed["service-name"] == "my-index"
    assert fixed["client-version"] is not None
    assert fixed["debug-header"] == "value123"
    assert fixed["debug-header2"] == "value456"

def test_init_with_grpc_config_from_dict(self):
pc = PineconeGRPC(api_key="YOUR_API_KEY")
config = GRPCClientConfig._from_dict({"timeout": 10})
Expand Down
12 changes: 6 additions & 6 deletions tests/unit_grpc/test_grpc_index_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@ def setup_method(self):
)

def test_query_byVectorNoFilter_queryVectorNoFilter(self, mocker, vals1):
mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True)
mocker.patch.object(self.index.runner, "run", autospec=True)
self.index.query(top_k=10, vector=vals1)
self.index._wrap_grpc_call.assert_called_once_with(
self.index.runner.run.assert_called_once_with(
self.index.stub.Query, QueryRequest(top_k=10, vector=vals1), timeout=None
)

def test_query_byVectorWithFilter_queryVectorWithFilter(self, mocker, vals1, filter1):
mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True)
mocker.patch.object(self.index.runner, "run", autospec=True)
self.index.query(top_k=10, vector=vals1, filter=filter1, namespace="ns", timeout=10)
self.index._wrap_grpc_call.assert_called_once_with(
self.index.runner.run.assert_called_once_with(
self.index.stub.Query,
QueryRequest(
top_k=10, vector=vals1, filter=dict_to_proto_struct(filter1), namespace="ns"
Expand All @@ -32,9 +32,9 @@ def test_query_byVectorWithFilter_queryVectorWithFilter(self, mocker, vals1, fil
)

def test_query_byVecId_queryByVecId(self, mocker):
mocker.patch.object(self.index, "_wrap_grpc_call", autospec=True)
mocker.patch.object(self.index.runner, "run", autospec=True)
self.index.query(top_k=10, id="vec1", include_metadata=True, include_values=False)
self.index._wrap_grpc_call.assert_called_once_with(
self.index.runner.run.assert_called_once_with(
self.index.stub.Query,
QueryRequest(top_k=10, id="vec1", include_metadata=True, include_values=False),
timeout=None,
Expand Down
Loading

0 comments on commit 2831a5e

Please sign in to comment.