-
Notifications
You must be signed in to change notification settings - Fork 23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support multiple openai versions in python #27
Changes from 5 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,8 @@ | ||
import dataclasses | ||
import json | ||
import sys | ||
import textwrap | ||
import time | ||
|
||
|
||
class SerializableDataClass: | ||
|
@@ -12,21 +15,40 @@ def as_json(self, **kwargs): | |
return json.dumps(self.as_dict(), **kwargs) | ||
|
||
|
||
class NoOpSpan: | ||
def log(self, **kwargs): | ||
# DEVNOTE: This is copied from braintrust-sdk/py/src/braintrust/logger.py | ||
class _NoopSpan: | ||
def __init__(self, *args, **kwargs): | ||
pass | ||
|
||
def start_span(self, *args, **kwargs): | ||
return self | ||
@property | ||
def id(self): | ||
return "" | ||
|
||
@property | ||
def span_id(self): | ||
return "" | ||
|
||
def end(self, *args, **kwargs): | ||
@property | ||
def root_span_id(self): | ||
return "" | ||
|
||
def log(self, **event): | ||
pass | ||
|
||
def start_span(self, name, span_attributes={}, start_time=None, set_current=None, **event): | ||
return self | ||
|
||
def end(self, end_time=None): | ||
return end_time or time.time() | ||
|
||
def close(self, end_time=None): | ||
return self.end(end_time) | ||
|
||
def __enter__(self): | ||
pass | ||
return self | ||
|
||
def __exit__(self, exc_type, exc_val, exc_tb): | ||
pass | ||
def __exit__(self, type, value, callback): | ||
del type, value, callback | ||
|
||
|
||
def current_span(): | ||
|
@@ -35,7 +57,7 @@ def current_span(): | |
|
||
return _get_current_span() | ||
except ImportError as e: | ||
return NoOpSpan() | ||
return _NoopSpan() | ||
|
||
|
||
def traced(*span_args, **span_kwargs): | ||
|
@@ -48,3 +70,53 @@ def traced(*span_args, **span_kwargs): | |
return span_args[0] | ||
else: | ||
return lambda f: f | ||
|
||
|
||
def prepare_openai_complete(is_async=False, api_key=None): | ||
try: | ||
import openai | ||
except Exception as e: | ||
print( | ||
textwrap.dedent( | ||
f"""\ | ||
Unable to import openai: {e} | ||
|
||
Please install it, e.g. with | ||
|
||
pip install 'openai' | ||
""" | ||
), | ||
file=sys.stderr, | ||
) | ||
raise | ||
|
||
openai_obj = openai | ||
is_v1 = False | ||
if hasattr(openai, "chat") and hasattr(openai.chat, "completions"): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are we expecting that someone has already imported openai and set
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah weird, I changed it to be |
||
# This is the new v1 API | ||
is_v1 = True | ||
if is_async: | ||
openai_obj = openai.AsyncOpenAI(api_key=api_key) | ||
else: | ||
openai_obj = openai.OpenAI(api_key=api_key) | ||
|
||
try: | ||
from braintrust.oai import wrap_openai | ||
|
||
openai_obj = wrap_openai(openai_obj) | ||
except ImportError: | ||
pass | ||
|
||
complete_fn = None | ||
rate_limit_error = None | ||
if is_v1: | ||
rate_limit_error = openai.RateLimitError | ||
complete_fn = openai_obj.chat.completions.create | ||
else: | ||
rate_limit_error = openai.error.RateLimitError | ||
if is_async: | ||
complete_fn = openai_obj.ChatCompletion.acreate | ||
else: | ||
complete_fn = openai_obj.ChatCompletion.create | ||
|
||
return complete_fn, rate_limit_error |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we do something where we add a version number to the matrix and pin that version number here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, I think we want the test to fail if openai releases another breaking change :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, I just think we could also make sure it's tested across all the versions we do support. Perhaps there's a way to encode "no version" as one of the options?