Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Extract GrpcRunner from GRPCIndexBase class #395

Merged
merged 4 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 4 additions & 57 deletions pinecone/grpc/base.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,14 @@
import logging
from abc import ABC, abstractmethod
from functools import wraps
from typing import Dict, Optional
from typing import Optional

import grpc
from grpc._channel import _InactiveRpcError, Channel
from grpc._channel import Channel

from .retry import RetryConfig
from .channel_factory import GrpcChannelFactory

from pinecone import Config
from .utils import _generate_request_id
from .config import GRPCClientConfig
from pinecone.utils.constants import REQUEST_ID, CLIENT_VERSION
from pinecone.exceptions.exceptions import PineconeException

_logger = logging.getLogger(__name__)
from .grpc_runner import GrpcRunner
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This wasn't being used, right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently unused. The _logger reference is now used in the GrpcChannelFactory which is logic I recently pulled out of this base class. Cleaning up this reference should have happened in that PR, but got overlooked.



class GRPCIndexBase(ABC):
Expand All @@ -35,18 +28,10 @@ def __init__(
):
self.config = config
self.grpc_client_config = grpc_config or GRPCClientConfig()
self.retry_config = self.grpc_client_config.retry_config or RetryConfig()

self.fixed_metadata = {
"api-key": config.api_key,
"service-name": index_name,
"client-version": CLIENT_VERSION,
}
if self.grpc_client_config.additional_metadata:
self.fixed_metadata.update(self.grpc_client_config.additional_metadata)

self._endpoint_override = _endpoint_override

self.runner = GrpcRunner(index_name=index_name, config=config, grpc_config=grpc_config)
self.channel_factory = GrpcChannelFactory(
config=self.config, grpc_client_config=self.grpc_client_config, use_asyncio=False
)
Expand Down Expand Up @@ -91,44 +76,6 @@ def close(self):
except TypeError:
pass

def _wrap_grpc_call(
self,
func,
request,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None,
):
@wraps(func)
def wrapped():
user_provided_metadata = metadata or {}
_metadata = tuple(
(k, v)
for k, v in {
**self.fixed_metadata,
**self._request_metadata(),
**user_provided_metadata,
}.items()
)
try:
return func(
request,
timeout=timeout,
metadata=_metadata,
credentials=credentials,
wait_for_ready=wait_for_ready,
compression=compression,
)
except _InactiveRpcError as e:
raise PineconeException(e._state.debug_error_string) from e

return wrapped()

def _request_metadata(self) -> Dict[str, str]:
return {REQUEST_ID: _generate_request_id()}

def __enter__(self):
return self

Expand Down
93 changes: 93 additions & 0 deletions pinecone/grpc/grpc_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from functools import wraps
from typing import Dict, Tuple

from grpc._channel import _InactiveRpcError

from pinecone import Config
from .utils import _generate_request_id
from .config import GRPCClientConfig
from pinecone.utils.constants import REQUEST_ID, CLIENT_VERSION
from pinecone.exceptions.exceptions import PineconeException


class GrpcRunner:
    """Runs gRPC stub calls on behalf of an index client, attaching the
    fixed per-client metadata (API key, service name, client version) plus
    any user-configured additional metadata to every request."""

    def __init__(self, index_name: str, config: Config, grpc_config: GRPCClientConfig):
        """Capture client config and precompute the metadata sent with every call.

        :param index_name: name of the target index; sent as ``service-name``.
        :param config: client configuration; supplies the API key.
        :param grpc_config: gRPC-specific options, may carry extra metadata.
        """
        self.config = config
        self.grpc_client_config = grpc_config

        base_metadata = {
            "api-key": config.api_key,
            "service-name": index_name,
            "client-version": CLIENT_VERSION,
        }
        # User-supplied entries override the defaults when keys collide.
        extra = self.grpc_client_config.additional_metadata
        self.fixed_metadata = {**base_metadata, **extra} if extra else base_metadata

def run(
self,
func,
request,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None,
):
@wraps(func)
def wrapped():
user_provided_metadata = metadata or {}
_metadata = self._prepare_metadata(user_provided_metadata)
try:
return func(
request,
timeout=timeout,
metadata=_metadata,
credentials=credentials,
wait_for_ready=wait_for_ready,
compression=compression,
)
except _InactiveRpcError as e:
raise PineconeException(e._state.debug_error_string) from e

return wrapped()

async def run_asyncio(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not using or testing run_asyncio anywhere yet, but I will. It's the same as run but with the addition of the async/await bits.

self,
func,
request,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None,
):
@wraps(func)
async def wrapped():
user_provided_metadata = metadata or {}
_metadata = self._prepare_metadata(user_provided_metadata)
try:
return await func(
request,
timeout=timeout,
metadata=_metadata,
credentials=credentials,
wait_for_ready=wait_for_ready,
compression=compression,
)
except _InactiveRpcError as e:
raise PineconeException(e._state.debug_error_string) from e

return await wrapped()

    def _prepare_metadata(
        self, user_provided_metadata: Dict[str, str]
    ) -> Tuple[Tuple[str, str], ...]:
        """Build the metadata tuple for one gRPC call.

        Merge precedence (later wins): fixed client metadata, then the
        per-call request id, then caller-supplied metadata. Returned as a
        tuple of (key, value) pairs, the form the grpc library expects.
        """
        return tuple(
            (k, v)
            for k, v in {
                **self.fixed_metadata,
                **self._request_metadata(),
                **user_provided_metadata,
            }.items()
        )

    def _request_metadata(self) -> Dict[str, str]:
        """Return per-call metadata: a freshly generated request id used to
        correlate client requests with server-side logs."""
        return {REQUEST_ID: _generate_request_id()}
20 changes: 10 additions & 10 deletions pinecone/grpc/index_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def upsert(
if async_req:
args_dict = self._parse_non_empty_args([("namespace", namespace)])
request = UpsertRequest(vectors=vectors, **args_dict, **kwargs)
future = self._wrap_grpc_call(self.stub.Upsert.future, request, timeout=timeout)
future = self.runner.run(self.stub.Upsert.future, request, timeout=timeout)
return PineconeGrpcFuture(future)

if batch_size is None:
Expand Down Expand Up @@ -163,7 +163,7 @@ def _upsert_batch(
) -> UpsertResponse:
args_dict = self._parse_non_empty_args([("namespace", namespace)])
request = UpsertRequest(vectors=vectors, **args_dict)
return self._wrap_grpc_call(self.stub.Upsert, request, timeout=timeout, **kwargs)
return self.runner.run(self.stub.Upsert, request, timeout=timeout, **kwargs)

def upsert_from_dataframe(
self,
Expand Down Expand Up @@ -280,10 +280,10 @@ def delete(

request = DeleteRequest(**args_dict, **kwargs)
if async_req:
future = self._wrap_grpc_call(self.stub.Delete.future, request, timeout=timeout)
future = self.runner.run(self.stub.Delete.future, request, timeout=timeout)
return PineconeGrpcFuture(future)
else:
return self._wrap_grpc_call(self.stub.Delete, request, timeout=timeout)
return self.runner.run(self.stub.Delete, request, timeout=timeout)

def fetch(
self, ids: Optional[List[str]], namespace: Optional[str] = None, **kwargs
Expand All @@ -308,7 +308,7 @@ def fetch(
args_dict = self._parse_non_empty_args([("namespace", namespace)])

request = FetchRequest(ids=ids, **args_dict, **kwargs)
response = self._wrap_grpc_call(self.stub.Fetch, request, timeout=timeout)
response = self.runner.run(self.stub.Fetch, request, timeout=timeout)
json_response = json_format.MessageToDict(response)
return parse_fetch_response(json_response)

Expand Down Expand Up @@ -388,7 +388,7 @@ def query(
request = QueryRequest(**args_dict)

timeout = kwargs.pop("timeout", None)
response = self._wrap_grpc_call(self.stub.Query, request, timeout=timeout)
response = self.runner.run(self.stub.Query, request, timeout=timeout)
json_response = json_format.MessageToDict(response)
return parse_query_response(json_response, _check_type=False)

Expand Down Expand Up @@ -451,10 +451,10 @@ def update(

request = UpdateRequest(id=id, **args_dict)
if async_req:
future = self._wrap_grpc_call(self.stub.Update.future, request, timeout=timeout)
future = self.runner.run(self.stub.Update.future, request, timeout=timeout)
return PineconeGrpcFuture(future)
else:
return self._wrap_grpc_call(self.stub.Update, request, timeout=timeout)
return self.runner.run(self.stub.Update, request, timeout=timeout)

def list_paginated(
self,
Expand Down Expand Up @@ -499,7 +499,7 @@ def list_paginated(
)
request = ListRequest(**args_dict, **kwargs)
timeout = kwargs.pop("timeout", None)
response = self._wrap_grpc_call(self.stub.List, request, timeout=timeout)
response = self.runner.run(self.stub.List, request, timeout=timeout)

if response.pagination and response.pagination.next != "":
pagination = Pagination(next=response.pagination.next)
Expand Down Expand Up @@ -572,7 +572,7 @@ def describe_index_stats(
timeout = kwargs.pop("timeout", None)

request = DescribeIndexStatsRequest(**args_dict)
response = self._wrap_grpc_call(self.stub.DescribeIndexStats, request, timeout=timeout)
response = self.runner.run(self.stub.DescribeIndexStats, request, timeout=timeout)
json_response = json_format.MessageToDict(response)
return parse_stats_response(json_response)

Expand Down
Loading