
Support multiple openai versions in python #27

Merged Nov 10, 2023 · 6 commits

Changes from 5 commits
5 changes: 4 additions & 1 deletion .github/workflows/python.yaml
@@ -6,6 +6,9 @@ on:
   push:
     branches: [main]
 
+env:
+  OPENAI_API_KEY: sk-dummy
+
 jobs:
   build:
     runs-on: ubuntu-latest
@@ -18,7 +21,7 @@ jobs:
         python-version: "3.11"
     - name: Install dependencies
       run: |
-        python -m pip install --upgrade pip setuptools build twine
+        python -m pip install --upgrade pip setuptools build twine openai
Contributor:

Should we do something where we add a version number to the matrix and pin that version number here?

Contributor Author:

No, I think we want the test to fail if openai releases another breaking change :)

Contributor:

Sure, I just think we could also make sure it's tested across all the versions we do support. Perhaps there's a way to encode "no version" as one of the options?
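For illustration, one way to encode that would be a matrix with an empty entry meaning "no pin" (a hypothetical sketch, not part of this PR; the pinned version numbers are placeholders):

    jobs:
      build:
        runs-on: ubuntu-latest
        strategy:
          matrix:
            # "" = no pin: install whatever pip resolves, so a new openai
            # release can still break the build and surface the failure
            openai-version: ["", "0.28.1", "1.1.1"]
        steps:
          - name: Install openai
            run: |
              if [ -n "${{ matrix.openai-version }}" ]; then
                pip install "openai==${{ matrix.openai-version }}"
              else
                pip install openai
              fi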

         python -m pip install -e .[dev]
     - name: Test with pytest
       run: |
12 changes: 8 additions & 4 deletions py/autoevals/llm.py
@@ -5,7 +5,6 @@
 from typing import List, Optional
 
 import chevron
-import openai
 import yaml
 
 from .base import Score, Scorer
@@ -129,16 +128,21 @@ def _render_messages(self, **kwargs):
         ]
 
     def _request_args(self, output, expected, **kwargs):
-        return dict(
-            Completion=openai.ChatCompletion,
+        ret = dict(
             model=self.model,
-            engine=self.engine,
             messages=self._render_messages(output=output, expected=expected, **kwargs),
             functions=self.classification_functions,
             function_call={"name": "select_choice"},
             **self.extra_args,
         )
 
+        if self.engine is not None:
+            # This parameter has been deprecated (https://help.openai.com/en/articles/6283125-what-happened-to-engines)
+            # and is unsupported in OpenAI v1, so only set it if the user has specified it
+            ret["engine"] = self.engine
+
+        return ret
+
     def _postprocess_response(self, resp):
         if len(resp["choices"]) > 0:
             return self._process_response(resp["choices"][0]["message"])
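To make the conditional explicit outside the diff, here is a minimal standalone sketch of the same pattern (illustrative names, not the library's API):

    def build_request_args(model, messages, engine=None, **extra):
        # Only include "engine" when the caller explicitly set it: the
        # parameter is deprecated, and OpenAI v1 rejects it outright.
        args = dict(model=model, messages=messages, **extra)
        if engine is not None:
            args["engine"] = engine
        return args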
76 changes: 34 additions & 42 deletions py/autoevals/oai.py
@@ -6,7 +6,7 @@
 import time
 from pathlib import Path
 
-from .util import current_span, traced
+from .util import current_span, prepare_openai_complete
 
 _CACHE_DIR = None
 _CONN = None
@@ -30,52 +30,53 @@ def open_cache():
     return _CONN
 
 
-def log_openai_request(input_args, response, **kwargs):
-    span = current_span()
-    if not span:
-        return
+CACHE_LOCK = threading.Lock()
 
-    input = input_args.pop("messages")
-    span.log(
-        metrics={
-            "tokens": response["usage"]["total_tokens"],
-            "prompt_tokens": response["usage"]["prompt_tokens"],
-            "completion_tokens": response["usage"]["completion_tokens"],
-        },
-        metadata={**input_args, **kwargs},
-        input=input,
-        output=response["choices"][0],
-    )
 
+def post_process_response(resp):
+    # This normalizes against craziness in OpenAI v0 vs. v1
+    if hasattr(resp, "to_dict"):
+        # v0
+        return resp.to_dict()
+    else:
+        # v1
+        return resp.dict()
 
-CACHE_LOCK = threading.Lock()
 
+def log_cached_response(params, resp):
+    with current_span().start_span(name="OpenAI Completion") as span:
+        messages = params.pop("messages", None)
+        span.log(
+            metrics={
+                "tokens": resp["usage"]["total_tokens"],
+                "prompt_tokens": resp["usage"]["prompt_tokens"],
+                "completion_tokens": resp["usage"]["completion_tokens"],
+            },
+            input=messages,
+            output=resp["choices"],
+        )
 
-@traced(name="OpenAI Completion")
-def run_cached_request(Completion=None, **kwargs):
-    # OpenAI is very slow to import, so we only do it if we need it
-    import openai
 
-    if Completion is None:
-        Completion = openai.Completion
+def run_cached_request(api_key=None, **kwargs):
+    # OpenAI is very slow to import, so we only do it if we need it
+    complete, RateLimitError = prepare_openai_complete(is_async=False, api_key=api_key)
 
     param_key = json.dumps(kwargs)
     conn = open_cache()
     with CACHE_LOCK:
         cursor = conn.cursor()
         resp = cursor.execute("""SELECT response FROM "cache" WHERE params=?""", [param_key]).fetchone()
     cached = False
     retries = 0
     if resp:
         cached = True
         resp = json.loads(resp[0])
+        log_cached_response(kwargs, resp)
     else:
         sleep_time = 0.1
         while retries < 20:
             try:
-                resp = Completion.create(**kwargs).to_dict()
+                resp = post_process_response(complete(**kwargs))
                 break
-            except openai.error.RateLimitError:
+            except RateLimitError:
                 sleep_time *= 1.5
                 time.sleep(sleep_time)
                 retries += 1
@@ -85,36 +86,29 @@ def run_cached_request(Completion=None, **kwargs):
         cursor.execute("""INSERT INTO "cache" VALUES (?, ?)""", [param_key, json.dumps(resp)])
         conn.commit()
 
-    log_openai_request(kwargs, resp, cached=cached)
-
     return resp
 
 
-@traced(name="OpenAI Completion")
-async def arun_cached_request(Completion=None, **kwargs):
-    # OpenAI is very slow to import, so we only do it if we need it
-    import openai
-
-    if Completion is None:
-        Completion = openai.Completion
+async def arun_cached_request(api_key=None, **kwargs):
+    complete, RateLimitError = prepare_openai_complete(is_async=True, api_key=api_key)
 
     param_key = json.dumps(kwargs)
     conn = open_cache()
     with CACHE_LOCK:
         cursor = conn.cursor()
         resp = cursor.execute("""SELECT response FROM "cache" WHERE params=?""", [param_key]).fetchone()
     cached = False
     retries = 0
     if resp:
         resp = json.loads(resp[0])
         cached = True
+        log_cached_response(kwargs, resp)
     else:
         sleep_time = 0.1
         while retries < 100:
             try:
-                resp = (await Completion.acreate(**kwargs)).to_dict()
+                resp = post_process_response(await complete(**kwargs))
                 break
-            except openai.error.RateLimitError:
+            except RateLimitError:
+                # Just assume it's a rate limit error
                 sleep_time *= 1.5
                 await asyncio.sleep(sleep_time)
                 retries += 1
@@ -124,6 +118,4 @@ async def arun_cached_request(Completion=None, **kwargs):
         cursor.execute("""INSERT INTO "cache" VALUES (?, ?)""", [param_key, json.dumps(resp)])
         conn.commit()
 
-    log_openai_request(kwargs, resp, cached=cached, retries=retries)
-
     return resp
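For context, a call site of the refactored sync helper might look like the following (a sketch; the model, key, and messages are placeholders). Because post_process_response normalizes v0 and v1 responses to plain dicts, the same indexing works under either SDK:

    from autoevals.oai import run_cached_request

    resp = run_cached_request(
        api_key="sk-...",  # optional; under v1 the OPENAI_API_KEY env var also works
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Say hello"}],
    )
    print(resp["choices"][0]["message"]["content"])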
90 changes: 81 additions & 9 deletions py/autoevals/util.py
@@ -1,5 +1,8 @@
 import dataclasses
 import json
+import sys
+import textwrap
+import time


class SerializableDataClass:
@@ -12,21 +15,40 @@ def as_json(self, **kwargs):
         return json.dumps(self.as_dict(), **kwargs)
 
 
-class NoOpSpan:
-    def log(self, **kwargs):
+# DEVNOTE: This is copied from braintrust-sdk/py/src/braintrust/logger.py
+class _NoopSpan:
+    def __init__(self, *args, **kwargs):
         pass
 
-    def start_span(self, *args, **kwargs):
-        return self
+    @property
+    def id(self):
+        return ""
+
+    @property
+    def span_id(self):
+        return ""
 
-    def end(self, *args, **kwargs):
+    @property
+    def root_span_id(self):
+        return ""
+
+    def log(self, **event):
         pass
 
+    def start_span(self, name, span_attributes={}, start_time=None, set_current=None, **event):
+        return self
+
+    def end(self, end_time=None):
+        return end_time or time.time()
+
+    def close(self, end_time=None):
+        return self.end(end_time)
+
     def __enter__(self):
-        pass
+        return self
 
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        pass
+    def __exit__(self, type, value, callback):
+        del type, value, callback
 
 
 def current_span():
@@ -35,7 +57,7 @@ def current_span():
 
         return _get_current_span()
     except ImportError as e:
-        return NoOpSpan()
+        return _NoopSpan()


def traced(*span_args, **span_kwargs):
@@ -48,3 +70,53 @@ def traced(*span_args, **span_kwargs):
         return span_args[0]
     else:
         return lambda f: f
+
+
+def prepare_openai_complete(is_async=False, api_key=None):
+    try:
+        import openai
+    except Exception as e:
+        print(
+            textwrap.dedent(
+                f"""\
+            Unable to import openai: {e}
+
+            Please install it, e.g. with
+
+            pip install 'openai'
+            """
+            ),
+            file=sys.stderr,
+        )
+        raise
+
+    openai_obj = openai
+    is_v1 = False
+    if hasattr(openai, "chat") and hasattr(openai.chat, "completions"):
Contributor:

Are we expecting that someone has already imported openai and set openai.api_key before getting here? Because if I just import openai from scratch and call `hasattr(openai.chat, "completions")`, I see the following error:

     91     api_key = os.environ.get("OPENAI_API_KEY")
     92 if api_key is None:
---> 93     raise OpenAIError(
     94         "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
     95     )
     96 self.api_key = api_key
     98 if organization is None:

OpenAIError: The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable

Contributor Author:

Yeah, weird. I changed it to be `hasattr(openai, "OpenAI")`.

+        # This is the new v1 API
+        is_v1 = True
+        if is_async:
+            openai_obj = openai.AsyncOpenAI(api_key=api_key)
+        else:
+            openai_obj = openai.OpenAI(api_key=api_key)
+
+    try:
+        from braintrust.oai import wrap_openai
+
+        openai_obj = wrap_openai(openai_obj)
+    except ImportError:
+        pass
+
+    complete_fn = None
+    rate_limit_error = None
+    if is_v1:
+        rate_limit_error = openai.RateLimitError
+        complete_fn = openai_obj.chat.completions.create
+    else:
+        rate_limit_error = openai.error.RateLimitError
+        if is_async:
+            complete_fn = openai_obj.ChatCompletion.acreate
+        else:
+            complete_fn = openai_obj.ChatCompletion.create
+
+    return complete_fn, rate_limit_error
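Picking up the detection thread above: checking for the v1 client class avoids the error the reviewer hit, because `hasattr(openai, "OpenAI")` merely looks up a name and never constructs a client, whereas accessing `openai.chat` in v1 goes through a module-level proxy that builds a default client and demands an API key. A sketch of that approach (this mirrors what the author says the code was changed to, not the exact merged code):

    import openai

    is_v1 = hasattr(openai, "OpenAI")  # the client class exists only in the v1 SDK
    if is_v1:
        # The API key is only required once a client is constructed
        client = openai.OpenAI(api_key="sk-...")
        complete = client.chat.completions.create
    else:
        complete = openai.ChatCompletion.create  # v0 module-level API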
2 changes: 1 addition & 1 deletion setup.py
@@ -11,7 +11,7 @@
 with open(os.path.join(dir_name, "README.md"), "r", encoding="utf-8") as f:
     long_description = f.read()
 
-install_requires = ["chevron", "openai==0.28.1", "levenshtein", "pyyaml"]
+install_requires = ["chevron", "levenshtein", "pyyaml"]
 
 extras_require = {
     "dev": [
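Since openai is no longer a hard dependency, one further option (not taken in this diff, shown only as a hypothetical) would be to expose it as an extra so users can opt in explicitly:

    extras_require = {
        # hypothetical extra: `pip install autoevals[openai]`
        "openai": ["openai"],
    }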