diff --git a/pinecone/config/config.py b/pinecone/config/config.py index 9f622429..a73ae87b 100644 --- a/pinecone/config/config.py +++ b/pinecone/config/config.py @@ -1,5 +1,6 @@ from typing import NamedTuple, Optional, Dict import os +import copy from pinecone.exceptions import PineconeConfigurationError from pinecone.config.openapi import OpenApiConfigFactory @@ -46,10 +47,11 @@ def build( if not host: raise PineconeConfigurationError("You haven't specified a host.") - openapi_config = ( - openapi_config - or kwargs.pop("openapi_config", None) - or OpenApiConfigFactory.build(api_key=api_key, host=host) - ) + if openapi_config: + openapi_config = copy.deepcopy(openapi_config) + openapi_config.host = host + openapi_config.api_key = {"ApiKeyAuth": api_key} + else: + openapi_config = OpenApiConfigFactory.build(api_key=api_key, host=host) return Config(api_key, host, openapi_config, additional_headers) \ No newline at end of file diff --git a/pinecone/control/pinecone.py b/pinecone/control/pinecone.py index 00a37ac8..ca8ebba4 100644 --- a/pinecone/control/pinecone.py +++ b/pinecone/control/pinecone.py @@ -6,8 +6,7 @@ from pinecone.config import PineconeConfig, Config from pinecone.core.client.api.manage_indexes_api import ManageIndexesApi -from pinecone.core.client.api_client import ApiClient -from pinecone.utils import get_user_agent, normalize_host +from pinecone.utils import normalize_host, setup_openapi_client from pinecone.core.client.models import ( CreateCollectionRequest, CreateIndexRequest, @@ -85,25 +84,20 @@ def __init__( or share with Pinecone support. **Be very careful with this option, as it will print out your API key** which forms part of a required authentication header. Default: `false` """ - if config or kwargs.get("config"): - configKwarg = config or kwargs.get("config") - if not isinstance(configKwarg, Config): + if config: + if not isinstance(config, Config): raise TypeError("config must be of type pinecone.config.Config") else: - self.config = configKwarg + self.config = config else: self.config = PineconeConfig.build(api_key=api_key, host=host, additional_headers=additional_headers, **kwargs) self.pool_threads = pool_threads + if index_api: self.index_api = index_api else: - api_client = ApiClient(configuration=self.config.openapi_config, pool_threads=self.pool_threads) - api_client.user_agent = get_user_agent() - extra_headers = self.config.additional_headers or {} - for key, value in extra_headers.items(): - api_client.set_default_header(key, value) - self.index_api = ManageIndexesApi(api_client) + self.index_api = setup_openapi_client(ManageIndexesApi, self.config, pool_threads) self.index_host_store = IndexHostStore() """ @private """ @@ -521,12 +515,20 @@ def Index(self, name: str = '', host: str = '', **kwargs): raise ValueError("Either name or host must be specified") pt = kwargs.pop('pool_threads', None) or self.pool_threads + api_key = self.config.api_key + openapi_config = self.config.openapi_config if host != '': # Use host url if it is provided - return Index(api_key=self.config.api_key, host=normalize_host(host), pool_threads=pt, **kwargs) - - if name != '': + index_host=normalize_host(host) + else: # Otherwise, get host url from describe_index using the index name index_host = self.index_host_store.get_host(self.index_api, self.config, name) - return Index(api_key=self.config.api_key, host=index_host, pool_threads=pt, **kwargs) + + return Index( + host=index_host, + api_key=api_key, + pool_threads=pt, + openapi_config=openapi_config, + **kwargs + ) \ No newline at end of file diff --git a/pinecone/data/index.py b/pinecone/data/index.py index 6befe5eb..87915e92 100644 --- a/pinecone/data/index.py +++ b/pinecone/data/index.py @@ -24,7 +24,7 @@ ListResponse ) from pinecone.core.client.api.data_plane_api import DataPlaneApi -from ..utils import get_user_agent +from ..utils import setup_openapi_client from .vector_factory import VectorFactory __all__ = [ @@ -75,27 +75,23 @@ def __init__( host: str, pool_threads: Optional[int] = 1, additional_headers: Optional[Dict[str, str]] = {}, + openapi_config = None, **kwargs ): - self._config = ConfigBuilder.build(api_key=api_key, host=host, **kwargs) - - api_client = ApiClient(configuration=self._config.openapi_config, - pool_threads=pool_threads) - - # Configure request headers - api_client.user_agent = get_user_agent() - extra_headers = additional_headers or {} - for key, value in extra_headers.items(): - api_client.set_default_header(key, value) - - self._api_client = api_client - self._vector_api = DataPlaneApi(api_client=api_client) + self._config = ConfigBuilder.build( + api_key=api_key, + host=host, + additional_headers=additional_headers, + openapi_config=openapi_config, + **kwargs + ) + self._vector_api = setup_openapi_client(DataPlaneApi, self._config, pool_threads) def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): - self._api_client.close() + self._vector_api.api_client.close() @validate_and_convert_errors def upsert( diff --git a/pinecone/utils/__init__.py b/pinecone/utils/__init__.py index e4ce28b6..27846fca 100644 --- a/pinecone/utils/__init__.py +++ b/pinecone/utils/__init__.py @@ -4,4 +4,5 @@ from .deprecation_notice import warn_deprecated from .fix_tuple_length import fix_tuple_length from .convert_to_list import convert_to_list -from .normalize_host import normalize_host \ No newline at end of file +from .normalize_host import normalize_host +from .setup_openapi_client import setup_openapi_client \ No newline at end of file diff --git a/pinecone/utils/setup_openapi_client.py b/pinecone/utils/setup_openapi_client.py new file mode 100644 index 00000000..2bc9c61f --- /dev/null +++ b/pinecone/utils/setup_openapi_client.py @@ -0,0 +1,14 @@ +from pinecone.core.client.api_client import ApiClient +from .user_agent import get_user_agent + +def setup_openapi_client(api_klass, config, pool_threads): + api_client = ApiClient( + configuration=config.openapi_config, + pool_threads=pool_threads + ) + api_client.user_agent = get_user_agent() + extra_headers = config.additional_headers or {} + for key, value in extra_headers.items(): + api_client.set_default_header(key, value) + client = api_klass(api_client) + return client diff --git a/tests/integration/data/conftest.py b/tests/integration/data/conftest.py index 90753b20..16b3e13b 100644 --- a/tests/integration/data/conftest.py +++ b/tests/integration/data/conftest.py @@ -27,6 +27,10 @@ def build_client(): from pinecone import Pinecone return Pinecone(api_key=api_key(), additional_headers={'sdk-test-suite': 'pinecone-python-client'}) +@pytest.fixture(scope='session') +def api_key_fixture(): + return api_key() + @pytest.fixture(scope='session') def client(): return build_client() diff --git a/tests/integration/data/test_openapi_configuration.py b/tests/integration/data/test_openapi_configuration.py new file mode 100644 index 00000000..e8b93389 --- /dev/null +++ b/tests/integration/data/test_openapi_configuration.py @@ -0,0 +1,18 @@ +import pytest +import os + +from pinecone import Pinecone +from pinecone.core.client.configuration import Configuration as OpenApiConfiguration +from urllib3 import make_headers + +@pytest.mark.skipif(os.getenv('USE_GRPC') != 'false', reason='Only test when using REST') +class TestIndexOpenapiConfig: + def test_passing_openapi_config(self, api_key_fixture, index_host): + oai_config = OpenApiConfiguration.get_default_copy() + p = Pinecone(api_key=api_key_fixture, openapi_config=oai_config) + assert p.config.api_key == api_key_fixture + p.list_indexes() # should not throw + + index = p.Index(host=index_host) + assert index._config.api_key == api_key_fixture + index.describe_index_stats() \ No newline at end of file diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index e8917e50..f1d6af7a 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -6,6 +6,8 @@ import pytest import os +from urllib3 import make_headers + class TestConfig: @pytest.fixture(autouse=True) def run_before_and_after_tests(tmpdir): @@ -49,13 +51,14 @@ def test_init_with_positional_args(self): def test_init_with_kwargs(self): api_key = "my-api-key" controller_host = "my-controller-host" - openapi_config = OpenApiConfiguration(api_key="openapi-api-key") + openapi_config = OpenApiConfiguration() + openapi_config.ssl_ca_cert = 'path/to/cert' config = PineconeConfig.build(api_key=api_key, host=controller_host, openapi_config=openapi_config) assert config.api_key == api_key assert config.host == 'https://' + controller_host - assert config.openapi_config == openapi_config + assert config.openapi_config.ssl_ca_cert == 'path/to/cert' def test_resolution_order_kwargs_over_env_vars(self): """ @@ -84,5 +87,43 @@ def test_config_pool_threads(self): pc = Pinecone(api_key="test-api-key", host="test-controller-host", pool_threads=10) assert pc.index_api.api_client.pool_threads == 10 idx = pc.Index(host='my-index-host', name='my-index-name') - assert idx._api_client.pool_threads == 10 + assert idx._vector_api.api_client.pool_threads == 10 + + def test_config_when_openapi_config_is_passed_merges_api_key(self): + oai_config = OpenApiConfiguration() + pc = Pinecone(api_key='asdf', openapi_config=oai_config) + assert pc.config.openapi_config.api_key == {'ApiKeyAuth': 'asdf'} + + def test_ssl_config_passed_to_index_client(self): + oai_config = OpenApiConfiguration() + oai_config.ssl_ca_cert = 'path/to/cert' + proxy_headers = make_headers(proxy_basic_auth='asdf') + oai_config.proxy_headers = proxy_headers + pc = Pinecone(api_key='key', openapi_config=oai_config) + + assert pc.config.openapi_config.ssl_ca_cert == 'path/to/cert' + assert pc.config.openapi_config.proxy_headers == proxy_headers + + idx = pc.Index(host='host') + assert idx._vector_api.api_client.configuration.ssl_ca_cert == 'path/to/cert' + assert idx._vector_api.api_client.configuration.proxy_headers == proxy_headers + + def test_host_config_not_clobbered_by_index(self): + oai_config = OpenApiConfiguration() + oai_config.ssl_ca_cert = 'path/to/cert' + proxy_headers = make_headers(proxy_basic_auth='asdf') + oai_config.proxy_headers = proxy_headers + + pc = Pinecone(api_key='key', openapi_config=oai_config) + + assert pc.config.openapi_config.ssl_ca_cert == 'path/to/cert' + assert pc.config.openapi_config.proxy_headers == proxy_headers + assert pc.config.openapi_config.host == 'https://api.pinecone.io' + + idx = pc.Index(host='host') + assert idx._vector_api.api_client.configuration.ssl_ca_cert == 'path/to/cert' + assert idx._vector_api.api_client.configuration.proxy_headers == proxy_headers + assert idx._vector_api.api_client.configuration.host == 'https://host' + + assert pc.config.openapi_config.host == 'https://api.pinecone.io' \ No newline at end of file diff --git a/tests/unit/test_config_builder.py b/tests/unit/test_config_builder.py new file mode 100644 index 00000000..7f4d63dc --- /dev/null +++ b/tests/unit/test_config_builder.py @@ -0,0 +1,36 @@ +import pytest + +from pinecone.core.client.configuration import Configuration as OpenApiConfiguration +from pinecone.config import ConfigBuilder +from pinecone import PineconeConfigurationError + +class TestConfigBuilder: + def test_build_simple(self): + config = ConfigBuilder.build(api_key="my-api-key", host="https://my-controller-host") + assert config.api_key == "my-api-key" + assert config.host == "https://my-controller-host" + assert config.additional_headers == {} + assert config.openapi_config.host == "https://my-controller-host" + assert config.openapi_config.api_key == {"ApiKeyAuth": "my-api-key"} + + def test_build_merges_key_and_host_when_openapi_config_provided(self): + config = ConfigBuilder.build( + api_key="my-api-key", + host="https://my-controller-host", + openapi_config=OpenApiConfiguration() + ) + assert config.api_key == "my-api-key" + assert config.host == "https://my-controller-host" + assert config.additional_headers == {} + assert config.openapi_config.host == "https://my-controller-host" + assert config.openapi_config.api_key == {"ApiKeyAuth": "my-api-key"} + + def test_build_errors_when_no_api_key_is_present(self): + with pytest.raises(PineconeConfigurationError) as e: + ConfigBuilder.build() + assert str(e.value) == "You haven't specified an Api-Key." + + def test_build_errors_when_no_host_is_present(self): + with pytest.raises(PineconeConfigurationError) as e: + ConfigBuilder.build(api_key='my-api-key') + assert str(e.value) == "You haven't specified a host." \ No newline at end of file diff --git a/tests/unit/test_control.py b/tests/unit/test_control.py index 18ccae9e..d058e2c6 100644 --- a/tests/unit/test_control.py +++ b/tests/unit/test_control.py @@ -2,6 +2,8 @@ from pinecone import Pinecone, PodSpec, ServerlessSpec from pinecone.core.client.models import IndexList, IndexModel from pinecone.core.client.api.manage_indexes_api import ManageIndexesApi +from pinecone.core.client.configuration import Configuration as OpenApiConfiguration + import time @pytest.fixture @@ -107,25 +109,29 @@ def test_list_indexes_returns_iterable(self, mocker, index_list_response): response = p.list_indexes() assert [i.name for i in response] == ["index1", "index2", "index3"] + def test_api_key_and_openapi_config(self, mocker): + p = Pinecone(api_key="123", openapi_config=OpenApiConfiguration.get_default_copy()) + assert p.config.api_key == "123" class TestIndexConfig: def test_default_pool_threads(self): pc = Pinecone(api_key="123-456-789") index = pc.Index(host='my-host.svg.pinecone.io') - assert index._api_client.pool_threads == 1 + assert index._vector_api.api_client.pool_threads == 1 def test_pool_threads_when_indexapi_passed(self): pc = Pinecone(api_key="123-456-789", pool_threads=2, index_api=ManageIndexesApi()) index = pc.Index(host='my-host.svg.pinecone.io') - assert index._api_client.pool_threads == 2 + assert index._vector_api.api_client.pool_threads == 2 def test_target_index_with_pool_threads_inherited(self): pc = Pinecone(api_key="123-456-789", pool_threads=10, foo='bar') index = pc.Index(host='my-host.svg.pinecone.io') - assert index._api_client.pool_threads == 10 + assert index._vector_api.api_client.pool_threads == 10 def test_target_index_with_pool_threads_kwarg(self): pc = Pinecone(api_key="123-456-789", pool_threads=10) index = pc.Index(host='my-host.svg.pinecone.io', pool_threads=5) - assert index._api_client.pool_threads == 5 + assert index._vector_api.api_client.pool_threads == 5 + diff --git a/tests/unit/test_index_initialization.py b/tests/unit/test_index_initialization.py index 0dc2545d..e0c4a30a 100644 --- a/tests/unit/test_index_initialization.py +++ b/tests/unit/test_index_initialization.py @@ -12,9 +12,9 @@ class TestIndexClientInitialization(): def test_no_additional_headers_leaves_useragent_only(self, additional_headers): pc = Pinecone(api_key='YOUR_API_KEY') index = pc.Index(host='myhost', additional_headers=additional_headers) - assert len(index._api_client.default_headers) == 1 - assert 'User-Agent' in index._api_client.default_headers - assert 'python-client-' in index._api_client.default_headers['User-Agent'] + assert len(index._vector_api.api_client.default_headers) == 1 + assert 'User-Agent' in index._vector_api.api_client.default_headers + assert 'python-client-' in index._vector_api.api_client.default_headers['User-Agent'] def test_additional_headers_one_additional(self): pc = Pinecone(api_key='YOUR_API_KEY') @@ -22,8 +22,8 @@ def test_additional_headers_one_additional(self): host='myhost', additional_headers={'test-header': 'test-header-value'} ) - assert 'test-header' in index._api_client.default_headers - assert len(index._api_client.default_headers) == 2 + assert 'test-header' in index._vector_api.api_client.default_headers + assert len(index._vector_api.api_client.default_headers) == 2 def test_multiple_additional_headers(self): pc = Pinecone(api_key='YOUR_API_KEY') @@ -34,9 +34,9 @@ def test_multiple_additional_headers(self): 'test-header2': 'test-header-value2' } ) - assert 'test-header' in index._api_client.default_headers - assert 'test-header2' in index._api_client.default_headers - assert len(index._api_client.default_headers) == 3 + assert 'test-header' in index._vector_api.api_client.default_headers + assert 'test-header2' in index._vector_api.api_client.default_headers + assert len(index._vector_api.api_client.default_headers) == 3 def test_overwrite_useragent(self): # This doesn't seem like a common use case, but we may want to allow this @@ -48,6 +48,6 @@ def test_overwrite_useragent(self): 'User-Agent': 'test-user-agent' } ) - assert len(index._api_client.default_headers) == 1 - assert 'User-Agent' in index._api_client.default_headers - assert index._api_client.default_headers['User-Agent'] == 'test-user-agent' \ No newline at end of file + assert len(index._vector_api.api_client.default_headers) == 1 + assert 'User-Agent' in index._vector_api.api_client.default_headers + assert index._vector_api.api_client.default_headers['User-Agent'] == 'test-user-agent' \ No newline at end of file