From db3535497f45e62091834ad0b5e625c0081ba198 Mon Sep 17 00:00:00 2001 From: TheDude Date: Thu, 26 Dec 2024 14:28:22 +0530 Subject: [PATCH] Add secrets reload withs safe retry --- plugins/openai/superduper_openai/model.py | 11 +++++-- superduper/base/config.py | 1 + superduper/base/exceptions.py | 8 +++++ superduper/misc/files.py | 19 ++++++++++++ superduper/misc/retry.py | 38 +++++++++++++++++++++++ 5 files changed, 74 insertions(+), 3 deletions(-) diff --git a/plugins/openai/superduper_openai/model.py b/plugins/openai/superduper_openai/model.py index 73565e3a0..b374bfc8f 100644 --- a/plugins/openai/superduper_openai/model.py +++ b/plugins/openai/superduper_openai/model.py @@ -19,7 +19,8 @@ from superduper.base.datalayer import Datalayer from superduper.components.model import APIBaseModel, Inputs from superduper.misc.compat import cache -from superduper.misc.retry import Retry +from superduper.misc.retry import Retry, safe_retry +from superduper.base import exceptions retry = Retry( exception_types=( @@ -61,19 +62,23 @@ def __post_init__(self, db, example): self.client_kwargs['base_url'] = self.openai_api_base self.client_kwargs['default_headers'] = self.openai_api_base + @safe_retry(exceptions.SecretsMissingException) + def init(self, db=None): + """Initialize the model.""" + super().init() + # dall-e is not currently included in list returned by OpenAI model endpoint if self.model not in ( mo := _available_models(json.dumps(self.client_kwargs)) ) and self.model not in ('dall-e'): msg = f'model {self.model} not in OpenAI available models, {mo}' raise ValueError(msg) - self.syncClient = SyncOpenAI(**self.client_kwargs) if 'OPENAI_API_KEY' not in os.environ and ( 'api_key' not in self.client_kwargs.keys() and self.client_kwargs ): - raise ValueError( + raise exceptions.SecretsMissingException( 'OPENAI_API_KEY not available neither in environment vars ' 'nor in `client_kwargs`' ) diff --git a/superduper/base/config.py b/superduper/base/config.py index 7d39b748e..351c41d11 100644 --- a/superduper/base/config.py +++ b/superduper/base/config.py @@ -174,6 +174,7 @@ class Config(BaseConfig): envs: dc.InitVar[t.Optional[t.Dict[str, str]]] = None data_backend: str = "mongodb://localhost:27017/test_db" + secret_dir: str = os.path.join(".superduper", "secrets") lance_home: str = os.path.join(".superduper", "vector_indices") diff --git a/superduper/base/exceptions.py b/superduper/base/exceptions.py index eeb0225ca..3a2b73958 100644 --- a/superduper/base/exceptions.py +++ b/superduper/base/exceptions.py @@ -89,3 +89,11 @@ class UnsupportedDatatype(BaseException): :param msg: msg for BaseException """ + + +class SecretsMissingException(BaseException): + """ + Missing secrets. + + :param msg: msg for BaseException + """ diff --git a/superduper/misc/files.py b/superduper/misc/files.py index 4ab053c0a..9b1fc083f 100644 --- a/superduper/misc/files.py +++ b/superduper/misc/files.py @@ -1,8 +1,27 @@ import hashlib +import os from superduper import CFG +def load_secrets(): + secrets_dir = CFG.secrets_dir + if not os.path.isdir(secrets_dir): + raise ValueError(f"The path '{secrets_dir}' is not a valid directory.") + + for root, _, files in os.walk(secrets_dir): + for file_name in files: + file_path = os.path.join(root, file_name) + try: + with open(file_path, 'r') as file: + content = file.read().strip() + + key = file_name + os.environ[key] = content + except Exception as e: + print(f"Error reading file {file_path}: {e}") + + def get_file_from_uri(uri): """ Get file name from uri. diff --git a/superduper/misc/retry.py b/superduper/misc/retry.py index a8762427c..f087a07e7 100644 --- a/superduper/misc/retry.py +++ b/superduper/misc/retry.py @@ -1,3 +1,4 @@ +import time import dataclasses as dc import functools import typing as t @@ -5,6 +6,8 @@ import tenacity import superduper as s +from superduper import logging +from superduper.misc.files import load_secrets ExceptionTypes = t.Union[t.Type[BaseException], t.Tuple[t.Type[BaseException], ...]] @@ -40,6 +43,41 @@ def __call__(self, f: t.Callable) -> t.Any: return retrier(f) +def safe_retry(exception_to_check, retries=1, delay=1): + """ + A decorator that retries a function if a specified exception is raised. + + :param exception_to_check: The exception or tuple of exceptions to check. + :param retries: The maximum number of retries. + :param delay: Delay between retries in seconds. + :return: The result of the decorated function. + """ + + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + attempt = 0 + while attempt < retries: + try: + load_secrets() + return func(*args, **kwargs) + except exception_to_check as e: + attempt += 1 + if attempt >= retries: + logging.error( + f"Function {func.__name__} failed after {retries} retries." + ) + raise + logging.warn( + f"Retrying {func.__name__} due to {e}, attempt {attempt} of {retries}..." + ) + time.sleep(delay) + + return wrapper + + return decorator + + def db_retry(connector='databackend'): """Helper method to retry methods with database calls.