Skip to content

Commit

Permalink
Add secrets reload withs safe retry
Browse files Browse the repository at this point in the history
  • Loading branch information
kartik4949 committed Dec 26, 2024
1 parent ceeb9bd commit db35354
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 3 deletions.
11 changes: 8 additions & 3 deletions plugins/openai/superduper_openai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from superduper.base.datalayer import Datalayer
from superduper.components.model import APIBaseModel, Inputs
from superduper.misc.compat import cache
from superduper.misc.retry import Retry
from superduper.misc.retry import Retry, safe_retry
from superduper.base import exceptions

retry = Retry(
exception_types=(
Expand Down Expand Up @@ -61,19 +62,23 @@ def __post_init__(self, db, example):
self.client_kwargs['base_url'] = self.openai_api_base
self.client_kwargs['default_headers'] = self.openai_api_base

@safe_retry(exceptions.SecretsMissingException)
def init(self, db=None):
"""Initialize the model."""
super().init()

# dall-e is not currently included in list returned by OpenAI model endpoint
if self.model not in (
mo := _available_models(json.dumps(self.client_kwargs))
) and self.model not in ('dall-e'):
msg = f'model {self.model} not in OpenAI available models, {mo}'
raise ValueError(msg)

self.syncClient = SyncOpenAI(**self.client_kwargs)

if 'OPENAI_API_KEY' not in os.environ and (
'api_key' not in self.client_kwargs.keys() and self.client_kwargs
):
raise ValueError(
raise exceptions.SecretsMissingException(
'OPENAI_API_KEY not available neither in environment vars '
'nor in `client_kwargs`'
)
Expand Down
1 change: 1 addition & 0 deletions superduper/base/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ class Config(BaseConfig):
envs: dc.InitVar[t.Optional[t.Dict[str, str]]] = None

data_backend: str = "mongodb://localhost:27017/test_db"
secret_dir: str = os.path.join(".superduper", "secrets")

lance_home: str = os.path.join(".superduper", "vector_indices")

Expand Down
8 changes: 8 additions & 0 deletions superduper/base/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,11 @@ class UnsupportedDatatype(BaseException):
:param msg: msg for BaseException
"""


class SecretsMissingException(BaseException):
"""
Missing secrets.
:param msg: msg for BaseException
"""
19 changes: 19 additions & 0 deletions superduper/misc/files.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,27 @@
import hashlib
import os

from superduper import CFG


def load_secrets():
secrets_dir = CFG.secrets_dir
if not os.path.isdir(secrets_dir):
raise ValueError(f"The path '{secrets_dir}' is not a valid directory.")

for root, _, files in os.walk(secrets_dir):
for file_name in files:
file_path = os.path.join(root, file_name)
try:
with open(file_path, 'r') as file:
content = file.read().strip()

key = file_name
os.environ[key] = content
except Exception as e:
print(f"Error reading file {file_path}: {e}")


def get_file_from_uri(uri):
"""
Get file name from uri.
Expand Down
38 changes: 38 additions & 0 deletions superduper/misc/retry.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import time
import dataclasses as dc
import functools
import typing as t

import tenacity

import superduper as s
from superduper import logging
from superduper.misc.files import load_secrets

ExceptionTypes = t.Union[t.Type[BaseException], t.Tuple[t.Type[BaseException], ...]]

Expand Down Expand Up @@ -40,6 +43,41 @@ def __call__(self, f: t.Callable) -> t.Any:
return retrier(f)


def safe_retry(exception_to_check, retries=1, delay=1):
"""
A decorator that retries a function if a specified exception is raised.
:param exception_to_check: The exception or tuple of exceptions to check.
:param retries: The maximum number of retries.
:param delay: Delay between retries in seconds.
:return: The result of the decorated function.
"""

def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
attempt = 0
while attempt < retries:
try:
load_secrets()
return func(*args, **kwargs)
except exception_to_check as e:
attempt += 1
if attempt >= retries:
logging.error(
f"Function {func.__name__} failed after {retries} retries."
)
raise
logging.warn(
f"Retrying {func.__name__} due to {e}, attempt {attempt} of {retries}..."
)
time.sleep(delay)

return wrapper

return decorator


def db_retry(connector='databackend'):
"""Helper method to retry methods with database calls.
Expand Down

0 comments on commit db35354

Please sign in to comment.