diff --git a/plugins/openai/pyproject.toml b/plugins/openai/pyproject.toml index 1cc7508c5..ed22a766e 100644 --- a/plugins/openai/pyproject.toml +++ b/plugins/openai/pyproject.toml @@ -26,7 +26,8 @@ dependencies = [ [project.optional-dependencies] test = [ - "vcrpy>=5.1.0", + "vcrpy==5.1.0", + "urllib3==2.2.3", ] [project.urls] diff --git a/plugins/openai/superduper_openai/__init__.py b/plugins/openai/superduper_openai/__init__.py index 891e18817..8aa4cf285 100644 --- a/plugins/openai/superduper_openai/__init__.py +++ b/plugins/openai/superduper_openai/__init__.py @@ -1,5 +1,5 @@ from .model import OpenAIChatCompletion, OpenAIEmbedding -__version__ = "0.4.2" +__version__ = "0.4.3" __all__ = 'OpenAIChatCompletion', 'OpenAIEmbedding' diff --git a/plugins/openai/superduper_openai/model.py b/plugins/openai/superduper_openai/model.py index 165f1ab1b..0bf36dab0 100644 --- a/plugins/openai/superduper_openai/model.py +++ b/plugins/openai/superduper_openai/model.py @@ -16,11 +16,11 @@ ) from openai._types import NOT_GIVEN from superduper.backends.query_dataset import QueryDataset +from superduper.base import exceptions from superduper.base.datalayer import Datalayer from superduper.components.model import APIBaseModel, Inputs from superduper.misc.compat import cache from superduper.misc.retry import Retry, safe_retry -from superduper.base import exceptions retry = Retry( exception_types=( @@ -55,27 +55,19 @@ def __post_init__(self, db, example): super().__post_init__(db, example) assert isinstance(self.client_kwargs, dict) - if self.openai_api_key is not None: self.client_kwargs['api_key'] = self.openai_api_key if self.openai_api_base is not None: self.client_kwargs['base_url'] = self.openai_api_base self.client_kwargs['default_headers'] = self.openai_api_base - @safe_retry(exceptions.MissingSecretsException) - def init(self, db=None): + @safe_retry(exceptions.MissingSecretsException, verbose=0) + def init(self): """Initialize the model.""" super().init() # dall-e is not currently included in list returned by OpenAI model endpoint - if self.model not in ( - mo := _available_models(json.dumps(self.client_kwargs)) - ) and self.model not in ('dall-e'): - msg = f'model {self.model} not in OpenAI available models, {mo}' - raise ValueError(msg) - self.syncClient = SyncOpenAI(**self.client_kwargs) - - if 'OPENAI_API_KEY' not in os.environ and ( + if 'OPENAI_API_KEY' not in os.environ or ( 'api_key' not in self.client_kwargs.keys() and self.client_kwargs ): raise exceptions.MissingSecretsException( @@ -83,6 +75,13 @@ def init(self, db=None): 'nor in `client_kwargs`' ) + if self.model not in ( + mo := _available_models(json.dumps(self.client_kwargs)) + ) and self.model not in ('dall-e'): + msg = f'model {self.model} not in OpenAI available models, {mo}' + raise ValueError(msg) + self.syncClient = SyncOpenAI(**self.client_kwargs) + def predict_batches(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: """Predict on a dataset. diff --git a/superduper/base/config.py b/superduper/base/config.py index 35088d282..55e72e19b 100644 --- a/superduper/base/config.py +++ b/superduper/base/config.py @@ -1,7 +1,7 @@ """Configuration variables for superduper.io. The classes in this file define the configuration variables for superduper.io, -which means that this file gets imported before alost anything else, and +hich means that this file gets imported before alost anything else, and canot contain any other imports from this project. """ @@ -148,6 +148,7 @@ class Config(BaseConfig): :param envs: The envs datas :param data_backend: The URI for the data backend + :param secrets_volume: The secrets volume mount for secrets env vars. :param lance_home: The home directory for the Lance vector indices, Default: .superduper/vector_indices :param artifact_store: The URI for the artifact store diff --git a/superduper/misc/retry.py b/superduper/misc/retry.py index 8e2968ab3..f294242df 100644 --- a/superduper/misc/retry.py +++ b/superduper/misc/retry.py @@ -43,13 +43,14 @@ def __call__(self, f: t.Callable) -> t.Any: return retrier(f) -def safe_retry(exception_to_check, retries=1, delay=1): +def safe_retry(exception_to_check, retries=1, delay=0.3, verbose=1): """ A decorator that retries a function if a specified exception is raised. :param exception_to_check: The exception or tuple of exceptions to check. :param retries: The maximum number of retries. :param delay: Delay between retries in seconds. + :param verbose: Verbose for logs. :return: The result of the decorated function. """ @@ -57,21 +58,25 @@ def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): attempt = 0 - while attempt < retries: + while attempt <= retries: try: - load_secrets() + if attempt >= 1: + load_secrets() return func(*args, **kwargs) except exception_to_check as e: attempt += 1 - if attempt >= retries: - logging.error( - f"Function {func.__name__} failed after {retries} retries." - ) + if attempt > retries: + if verbose: + logging.error( + f"Function {func.__name__} failed ", + "after {retries} retries.", + ) raise - logging.warn( - f"Retrying {func.__name__} due to {e}" - ", attempt {attempt} of {retries}..." - ) + if verbose: + logging.warn( + f"Retrying {func.__name__} due to {e}" + ", attempt {attempt} of {retries}..." + ) time.sleep(delay) return wrapper