Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add secrets reload withs safe retry #2702

Merged
merged 3 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Support `<var:template_staged_file>` in Template to enable apps to use files/folders included in the template
- Add Data Component for storing data directly in the template
- Add a standalone flag in Streamlit to mark the page as independent.
- Add secrets directory mount for loading secret env vars.

#### Bug Fixes

Expand Down
3 changes: 2 additions & 1 deletion plugins/openai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ dependencies = [

[project.optional-dependencies]
test = [
"vcrpy>=5.1.0",
"vcrpy==5.1.0",
"urllib3==2.2.3",
]

[project.urls]
Expand Down
2 changes: 1 addition & 1 deletion plugins/openai/superduper_openai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .model import OpenAIChatCompletion, OpenAIEmbedding

__version__ = "0.4.2"
__version__ = "0.4.3"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't work.


__all__ = 'OpenAIChatCompletion', 'OpenAIEmbedding'
26 changes: 15 additions & 11 deletions plugins/openai/superduper_openai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
)
from openai._types import NOT_GIVEN
from superduper.backends.query_dataset import QueryDataset
from superduper.base import exceptions
from superduper.base.datalayer import Datalayer
from superduper.components.model import APIBaseModel, Inputs
from superduper.misc.compat import cache
from superduper.misc.retry import Retry
from superduper.misc.retry import Retry, safe_retry

retry = Retry(
exception_types=(
Expand Down Expand Up @@ -54,30 +55,33 @@ def __post_init__(self, db, example):
super().__post_init__(db, example)

assert isinstance(self.client_kwargs, dict)

if self.openai_api_key is not None:
self.client_kwargs['api_key'] = self.openai_api_key
if self.openai_api_base is not None:
self.client_kwargs['base_url'] = self.openai_api_base
self.client_kwargs['default_headers'] = self.openai_api_base

@safe_retry(exceptions.MissingSecretsException, verbose=0)
def init(self):
"""Initialize the model."""
super().init()

# dall-e is not currently included in list returned by OpenAI model endpoint
if 'OPENAI_API_KEY' not in os.environ or (
'api_key' not in self.client_kwargs.keys() and self.client_kwargs
):
raise exceptions.MissingSecretsException(
'OPENAI_API_KEY not available neither in environment vars '
'nor in `client_kwargs`'
)

if self.model not in (
mo := _available_models(json.dumps(self.client_kwargs))
) and self.model not in ('dall-e'):
msg = f'model {self.model} not in OpenAI available models, {mo}'
raise ValueError(msg)

self.syncClient = SyncOpenAI(**self.client_kwargs)

if 'OPENAI_API_KEY' not in os.environ and (
'api_key' not in self.client_kwargs.keys() and self.client_kwargs
):
raise ValueError(
'OPENAI_API_KEY not available neither in environment vars '
'nor in `client_kwargs`'
)

def predict_batches(self, dataset: t.Union[t.List, QueryDataset]) -> t.List:
"""Predict on a dataset.

Expand Down
4 changes: 3 additions & 1 deletion superduper/base/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Configuration variables for superduper.io.

The classes in this file define the configuration variables for superduper.io,
which means that this file gets imported before alost anything else, and
hich means that this file gets imported before alost anything else, and
canot contain any other imports from this project.
"""

Expand Down Expand Up @@ -148,6 +148,7 @@ class Config(BaseConfig):

:param envs: The envs datas
:param data_backend: The URI for the data backend
:param secrets_volume: The secrets volume mount for secrets env vars.
:param lance_home: The home directory for the Lance vector indices,
Default: .superduper/vector_indices
:param artifact_store: The URI for the artifact store
Expand All @@ -174,6 +175,7 @@ class Config(BaseConfig):
envs: dc.InitVar[t.Optional[t.Dict[str, str]]] = None

data_backend: str = "mongodb://localhost:27017/test_db"
secrets_volume: str = os.path.join(".superduper", "/session/secrets")

lance_home: str = os.path.join(".superduper", "vector_indices")

Expand Down
8 changes: 8 additions & 0 deletions superduper/base/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,11 @@ class UnsupportedDatatype(BaseException):

:param msg: msg for BaseException
"""


class MissingSecretsException(BaseException):
"""
Missing secrets.

:param msg: msg for BaseException
"""
20 changes: 20 additions & 0 deletions superduper/misc/files.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,28 @@
import hashlib
import os

from superduper import CFG


def load_secrets():
"""Help method to load secrets from directory."""
secrets_dir = CFG.secrets_volume
if not os.path.isdir(secrets_dir):
raise ValueError(f"The path '{secrets_dir}' is not a valid directory.")

for root, _, files in os.walk(secrets_dir):
for file_name in files:
file_path = os.path.join(root, file_name)
try:
with open(file_path, 'r') as file:
content = file.read().strip()

key = file_name
os.environ[key] = content
except Exception as e:
print(f"Error reading file {file_path}: {e}")


def get_file_from_uri(uri):
"""
Get file name from uri.
Expand Down
44 changes: 44 additions & 0 deletions superduper/misc/retry.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import dataclasses as dc
import functools
import time
import typing as t

import tenacity

import superduper as s
from superduper import logging
from superduper.misc.files import load_secrets

ExceptionTypes = t.Union[t.Type[BaseException], t.Tuple[t.Type[BaseException], ...]]

Expand Down Expand Up @@ -40,6 +43,47 @@ def __call__(self, f: t.Callable) -> t.Any:
return retrier(f)


def safe_retry(exception_to_check, retries=1, delay=0.3, verbose=1):
"""
A decorator that retries a function if a specified exception is raised.

:param exception_to_check: The exception or tuple of exceptions to check.
:param retries: The maximum number of retries.
:param delay: Delay between retries in seconds.
:param verbose: Verbose for logs.
:return: The result of the decorated function.
"""

def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
attempt = 0
while attempt <= retries:
try:
if attempt >= 1:
load_secrets()
return func(*args, **kwargs)
except exception_to_check as e:
attempt += 1
if attempt > retries:
if verbose:
logging.error(
f"Function {func.__name__} failed ",
"after {retries} retries.",
)
raise
if verbose:
logging.warn(
f"Retrying {func.__name__} due to {e}"
", attempt {attempt} of {retries}..."
)
time.sleep(delay)

return wrapper

return decorator


def db_retry(connector='databackend'):
"""Helper method to retry methods with database calls.

Expand Down
Loading