-
Notifications
You must be signed in to change notification settings - Fork 85
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Refactor] Extract GrpcChannelFactory from GRPCIndexBase (#394)
## Problem

I'm preparing to implement asyncio for the data plane, and I needed to extract some of this gRPC channel configuration into a place where it can be reused more easily across both the sync and async implementations.

## Solution

- Extract `GrpcChannelFactory` from `GRPCIndexBase`
- Add some unit tests for this new class

## Type of Change

- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] This change requires a documentation update
- [ ] Infrastructure change (CI configs, etc)
- [ ] Non-code change (docs, etc)
- [x] None of the above: Refactoring only, should be no functional change
- Loading branch information
Showing
3 changed files
with
248 additions
and
64 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import logging | ||
from typing import Optional | ||
|
||
import certifi | ||
import grpc | ||
import json | ||
|
||
from pinecone import Config | ||
from .config import GRPCClientConfig | ||
from pinecone.utils.constants import MAX_MSG_SIZE | ||
from pinecone.utils.user_agent import get_user_agent_grpc | ||
|
||
_logger = logging.getLogger(__name__) | ||
|
||
|
||
class GrpcChannelFactory:
    """Create gRPC channels (sync or asyncio) from client configuration.

    The factory reads the proxy URL and CA bundle path from ``config``, the
    ``secure`` flag from ``grpc_client_config``, and builds channels through
    ``grpc`` or ``grpc.aio`` depending on ``use_asyncio``.
    """

    def __init__(
        self,
        config: Config,
        grpc_client_config: GRPCClientConfig,
        use_asyncio: Optional[bool] = False,
    ):
        """Store the configuration used for every channel this factory creates.

        Args:
            config: Client configuration; ``proxy_url`` and ``ssl_ca_certs`` are
                read from it, and it is passed to ``get_user_agent_grpc``.
            grpc_client_config: gRPC settings; ``secure`` selects TLS vs. plaintext.
            use_asyncio: When truthy, channels are created with ``grpc.aio``.
        """
        self.config = config
        self.grpc_client_config = grpc_client_config
        self.use_asyncio = use_asyncio

    def _get_service_config(self) -> str:
        """Return the gRPC service config (retry policy) serialized as JSON."""
        # https://github.com/grpc/grpc-proto/blob/master/grpc/service_config/service_config.proto
        # Both method configs share the same policy: up to 5 attempts with
        # exponential backoff from 0.1s to 1s, retrying only on UNAVAILABLE.
        retry_policy = {
            "maxAttempts": 5,
            "initialBackoff": "0.1s",
            "maxBackoff": "1s",
            "backoffMultiplier": 2,
            "retryableStatusCodes": ["UNAVAILABLE"],
        }
        # NOTE(review): the first entry puts "VectorService.Upsert" in the
        # "service" field rather than using a separate "method" key; kept as-is
        # to preserve behavior -- confirm this is intended.
        return json.dumps(
            {
                "methodConfig": [
                    {
                        "name": [{"service": "VectorService.Upsert"}],
                        "retryPolicy": retry_policy,
                    },
                    {
                        "name": [{"service": "VectorService"}],
                        "retryPolicy": retry_policy,
                    },
                ]
            }
        )

    def _build_options(self, target):
        """Build the channel options as a tuple of ``(name, value)`` pairs.

        Args:
            target: Endpoint string; the part before the first ``":"`` is used
                as the TLS target name override when the channel is secure.

        Returns:
            A tuple of ``(option_name, option_value)`` tuples.
        """
        # For property definitions, see https://github.com/grpc/grpc/blob/v1.43.x/include/grpc/impl/codegen/grpc_types.h
        options = {
            "grpc.max_send_message_length": MAX_MSG_SIZE,
            "grpc.max_receive_message_length": MAX_MSG_SIZE,
            "grpc.service_config": self._get_service_config(),
            "grpc.enable_retries": True,
            "grpc.per_rpc_retry_buffer_size": MAX_MSG_SIZE,
            "grpc.primary_user_agent": get_user_agent_grpc(self.config),
        }
        if self.grpc_client_config.secure:
            options["grpc.ssl_target_name_override"] = target.split(":")[0]
        if self.config.proxy_url:
            options["grpc.http_proxy"] = self.config.proxy_url

        return tuple(options.items())

    def _build_channel_credentials(self):
        """Load root certificates and return SSL channel credentials.

        Uses ``config.ssl_ca_certs`` when set, otherwise the bundle located by
        ``certifi.where()``.
        """
        ca_certs = self.config.ssl_ca_certs if self.config.ssl_ca_certs else certifi.where()
        # Close the file deterministically instead of leaking the handle until
        # garbage collection. An explicit try/finally (rather than ``with``)
        # reads from the very object ``open`` returns, which keeps simple
        # ``open`` test doubles working.
        cert_file = open(ca_certs, "rb")
        try:
            root_cas = cert_file.read()
        finally:
            cert_file.close()
        return grpc.ssl_channel_credentials(root_certificates=root_cas)

    def create_channel(self, endpoint):
        """Create a channel to ``endpoint`` using the stored configuration.

        Args:
            endpoint: Target passed to the gRPC channel constructor.

        Returns:
            The object returned by ``grpc.insecure_channel`` /
            ``grpc.secure_channel`` or their ``grpc.aio`` counterparts,
            depending on ``grpc_client_config.secure`` and ``use_asyncio``.
        """
        options_tuple = self._build_options(endpoint)

        _logger.debug(
            "Creating new channel with endpoint %s options %s and config %s",
            endpoint,
            options_tuple,
            self.grpc_client_config,
        )

        if not self.grpc_client_config.secure:
            create_channel_fn = (
                grpc.aio.insecure_channel if self.use_asyncio else grpc.insecure_channel
            )
            return create_channel_fn(endpoint, options=options_tuple)

        channel_creds = self._build_channel_credentials()
        create_channel_fn = grpc.aio.secure_channel if self.use_asyncio else grpc.secure_channel
        return create_channel_fn(endpoint, credentials=channel_creds, options=options_tuple)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
import grpc | ||
import re | ||
import pytest | ||
from unittest.mock import patch, MagicMock, ANY | ||
|
||
from pinecone import Config | ||
from pinecone.grpc.channel_factory import GrpcChannelFactory, GRPCClientConfig | ||
from pinecone.utils.constants import MAX_MSG_SIZE | ||
|
||
|
||
@pytest.fixture
def config():
    """Provide a ``Config`` with no custom CA bundle and no proxy URL."""
    default_config = Config(ssl_ca_certs=None, proxy_url=None)
    return default_config
|
||
|
||
@pytest.fixture
def grpc_client_config():
    """Provide a ``GRPCClientConfig`` with ``secure`` enabled."""
    secure_client_config = GRPCClientConfig(secure=True)
    return secure_client_config
|
||
|
||
class TestGrpcChannelFactory:
    """Tests for synchronous channel creation via ``GrpcChannelFactory``."""

    def test_create_secure_channel_with_default_settings(self, config, grpc_client_config):
        """A secure factory calls ``grpc.secure_channel`` with options and credentials."""
        factory = GrpcChannelFactory(
            config=config, grpc_client_config=grpc_client_config, use_asyncio=False
        )
        endpoint = "test.endpoint:443"

        with patch("grpc.secure_channel") as mock_secure_channel, patch(
            "certifi.where", return_value="/path/to/certifi/cacert.pem"
        ), patch("builtins.open", new_callable=MagicMock) as mock_open:
            # Mock the file object to return bytes when read() is called
            mock_file = MagicMock()
            mock_file.read.return_value = b"mocked_cert_data"
            mock_open.return_value = mock_file
            channel = factory.create_channel(endpoint)

        mock_secure_channel.assert_called_once()
        assert mock_secure_channel.call_args[0][0] == endpoint
        assert isinstance(mock_secure_channel.call_args[1]["options"], tuple)

        options = dict(mock_secure_channel.call_args[1]["options"])
        assert options["grpc.ssl_target_name_override"] == "test.endpoint"
        assert options["grpc.max_send_message_length"] == MAX_MSG_SIZE
        assert options["grpc.per_rpc_retry_buffer_size"] == MAX_MSG_SIZE
        assert options["grpc.max_receive_message_length"] == MAX_MSG_SIZE
        assert "grpc.service_config" in options
        assert options["grpc.enable_retries"] is True
        assert (
            re.search(
                r"python-client\[grpc\]-\d+\.\d+\.\d+", options["grpc.primary_user_agent"]
            )
            is not None
        )

        # Same credential checks as the asyncio variant, so both code paths
        # are verified to pass real channel credentials.
        security_credentials = mock_secure_channel.call_args[1]["credentials"]
        assert security_credentials is not None
        assert isinstance(security_credentials, grpc.ChannelCredentials)

        assert isinstance(channel, MagicMock)

    def test_create_secure_channel_with_proxy(self):
        """A configured ``proxy_url`` is forwarded as the ``grpc.http_proxy`` option."""
        grpc_client_config = GRPCClientConfig(secure=True)
        config = Config(proxy_url="http://test.proxy:8080")
        factory = GrpcChannelFactory(
            config=config, grpc_client_config=grpc_client_config, use_asyncio=False
        )
        endpoint = "test.endpoint:443"

        with patch("grpc.secure_channel") as mock_secure_channel:
            channel = factory.create_channel(endpoint)

        mock_secure_channel.assert_called_once()
        # Build the options mapping once and reuse it for both assertions.
        options = dict(mock_secure_channel.call_args[1]["options"])
        assert "grpc.http_proxy" in options
        assert options["grpc.http_proxy"] == "http://test.proxy:8080"
        assert isinstance(channel, MagicMock)

    def test_create_insecure_channel(self, config):
        """An insecure factory calls ``grpc.insecure_channel`` without TLS options."""
        grpc_client_config = GRPCClientConfig(secure=False)
        factory = GrpcChannelFactory(
            config=config, grpc_client_config=grpc_client_config, use_asyncio=False
        )
        endpoint = "test.endpoint:50051"

        with patch("grpc.insecure_channel") as mock_insecure_channel:
            channel = factory.create_channel(endpoint)

        mock_insecure_channel.assert_called_once_with(endpoint, options=ANY)
        # The TLS target-name override is only added for secure channels.
        options = dict(mock_insecure_channel.call_args[1]["options"])
        assert "grpc.ssl_target_name_override" not in options
        assert isinstance(channel, MagicMock)
|
||
|
||
class TestGrpcChannelFactoryAsyncio:
    """Tests for asyncio channel creation via ``GrpcChannelFactory``."""

    def test_create_secure_channel_with_default_settings(self, config, grpc_client_config):
        """A secure asyncio factory calls ``grpc.aio.secure_channel`` with options and credentials."""
        endpoint = "test.endpoint:443"
        factory = GrpcChannelFactory(
            config=config, grpc_client_config=grpc_client_config, use_asyncio=True
        )

        # File-like stand-in whose read() yields bytes for the CA bundle.
        fake_cert_file = MagicMock()
        fake_cert_file.read.return_value = b"mocked_cert_data"

        with patch("grpc.aio.secure_channel") as aio_secure_channel:
            with patch("certifi.where", return_value="/path/to/certifi/cacert.pem"):
                with patch("builtins.open", new_callable=MagicMock) as fake_open:
                    fake_open.return_value = fake_cert_file
                    channel = factory.create_channel(endpoint)

        aio_secure_channel.assert_called_once()
        call_positional, call_keywords = aio_secure_channel.call_args
        assert call_positional[0] == endpoint

        raw_options = call_keywords["options"]
        assert isinstance(raw_options, tuple)

        options = dict(raw_options)
        assert options["grpc.ssl_target_name_override"] == "test.endpoint"
        for size_option in (
            "grpc.max_send_message_length",
            "grpc.per_rpc_retry_buffer_size",
            "grpc.max_receive_message_length",
        ):
            assert options[size_option] == MAX_MSG_SIZE
        assert "grpc.service_config" in options
        assert options["grpc.enable_retries"] is True

        user_agent_match = re.search(
            r"python-client\[grpc\]-\d+\.\d+\.\d+", options["grpc.primary_user_agent"]
        )
        assert user_agent_match is not None

        security_credentials = call_keywords["credentials"]
        assert security_credentials is not None
        assert isinstance(security_credentials, grpc.ChannelCredentials)

        assert isinstance(channel, MagicMock)

    def test_create_insecure_channel_asyncio(self, config):
        """An insecure asyncio factory calls ``grpc.aio.insecure_channel``."""
        endpoint = "test.endpoint:50051"
        insecure_client_config = GRPCClientConfig(secure=False)
        factory = GrpcChannelFactory(
            config=config, grpc_client_config=insecure_client_config, use_asyncio=True
        )

        with patch("grpc.aio.insecure_channel") as aio_insecure_channel:
            channel = factory.create_channel(endpoint)

        aio_insecure_channel.assert_called_once_with(endpoint, options=ANY)
        assert isinstance(channel, MagicMock)