Skip to content

Commit

Permalink
[PLUGINS] Bump Version [openai]
Browse files Browse the repository at this point in the history
  • Loading branch information
kartik4949 committed Dec 27, 2024
1 parent c556e6e commit a2ebc17
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 26 deletions.
3 changes: 2 additions & 1 deletion plugins/openai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ dependencies = [

[project.optional-dependencies]
test = [
"vcrpy>=5.1.0",
"vcrpy==5.1.0",
"urllib3==2.2.3",
]

[project.urls]
Expand Down
2 changes: 1 addition & 1 deletion plugins/openai/superduper_openai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .model import OpenAIChatCompletion, OpenAIEmbedding

__version__ = "0.4.2"
__version__ = "0.4.3"

__all__ = 'OpenAIChatCompletion', 'OpenAIEmbedding'
23 changes: 11 additions & 12 deletions plugins/openai/superduper_openai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
)
from openai._types import NOT_GIVEN
from superduper.backends.query_dataset import QueryDataset
from superduper.base import exceptions
from superduper.base.datalayer import Datalayer
from superduper.components.model import APIBaseModel, Inputs
from superduper.misc.compat import cache
from superduper.misc.retry import Retry, safe_retry
from superduper.base import exceptions

retry = Retry(
exception_types=(
Expand Down Expand Up @@ -55,34 +55,33 @@ def __post_init__(self, db, example):
super().__post_init__(db, example)

assert isinstance(self.client_kwargs, dict)

if self.openai_api_key is not None:
self.client_kwargs['api_key'] = self.openai_api_key
if self.openai_api_base is not None:
self.client_kwargs['base_url'] = self.openai_api_base
self.client_kwargs['default_headers'] = self.openai_api_base

@safe_retry(exceptions.MissingSecretsException)
def init(self, db=None):
@safe_retry(exceptions.MissingSecretsException, verbose=0)
def init(self):
"""Initialize the model."""
super().init()

# dall-e is not currently included in list returned by OpenAI model endpoint
if self.model not in (
mo := _available_models(json.dumps(self.client_kwargs))
) and self.model not in ('dall-e'):
msg = f'model {self.model} not in OpenAI available models, {mo}'
raise ValueError(msg)
self.syncClient = SyncOpenAI(**self.client_kwargs)

if 'OPENAI_API_KEY' not in os.environ and (
if 'OPENAI_API_KEY' not in os.environ or (
'api_key' not in self.client_kwargs.keys() and self.client_kwargs
):
raise exceptions.MissingSecretsException(
'OPENAI_API_KEY not available neither in environment vars '
'nor in `client_kwargs`'
)

if self.model not in (
mo := _available_models(json.dumps(self.client_kwargs))
) and self.model not in ('dall-e'):
msg = f'model {self.model} not in OpenAI available models, {mo}'
raise ValueError(msg)
self.syncClient = SyncOpenAI(**self.client_kwargs)

def predict_batches(self, dataset: t.Union[t.List, QueryDataset]) -> t.List:
"""Predict on a dataset.
Expand Down
3 changes: 2 additions & 1 deletion superduper/base/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Configuration variables for superduper.io.
The classes in this file define the configuration variables for superduper.io,
which means that this file gets imported before alost anything else, and
hich means that this file gets imported before alost anything else, and
canot contain any other imports from this project.
"""

Expand Down Expand Up @@ -148,6 +148,7 @@ class Config(BaseConfig):
:param envs: The envs datas
:param data_backend: The URI for the data backend
:param secrets_volume: The secrets volume mount for secrets env vars.
:param lance_home: The home directory for the Lance vector indices,
Default: .superduper/vector_indices
:param artifact_store: The URI for the artifact store
Expand Down
27 changes: 16 additions & 11 deletions superduper/misc/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,35 +43,40 @@ def __call__(self, f: t.Callable) -> t.Any:
return retrier(f)


def safe_retry(exception_to_check, retries=1, delay=1):
def safe_retry(exception_to_check, retries=1, delay=0.3, verbose=1):
"""
A decorator that retries a function if a specified exception is raised.
:param exception_to_check: The exception or tuple of exceptions to check.
:param retries: The maximum number of retries.
:param delay: Delay between retries in seconds.
:param verbose: Verbose for logs.
:return: The result of the decorated function.
"""

def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
attempt = 0
while attempt < retries:
while attempt <= retries:
try:
load_secrets()
if attempt >= 1:
load_secrets()
return func(*args, **kwargs)
except exception_to_check as e:
attempt += 1
if attempt >= retries:
logging.error(
f"Function {func.__name__} failed after {retries} retries."
)
if attempt > retries:
if verbose:
logging.error(
f"Function {func.__name__} failed ",
"after {retries} retries.",
)
raise
logging.warn(
f"Retrying {func.__name__} due to {e}"
", attempt {attempt} of {retries}..."
)
if verbose:
logging.warn(
f"Retrying {func.__name__} due to {e}"
", attempt {attempt} of {retries}..."
)
time.sleep(delay)

return wrapper
Expand Down

0 comments on commit a2ebc17

Please sign in to comment.