diff --git a/.github/workflows/python.yaml b/.github/workflows/python.yaml
index add3903..5660109 100644
--- a/.github/workflows/python.yaml
+++ b/.github/workflows/python.yaml
@@ -6,6 +6,9 @@ on:
   push:
     branches: [main]
 
+env:
+  OPENAI_API_KEY: sk-dummy
+
 jobs:
   build:
     runs-on: ubuntu-latest
@@ -18,7 +21,7 @@ jobs:
           python-version: "3.11"
       - name: Install dependencies
        run: |
-          python -m pip install --upgrade pip setuptools build twine
+          python -m pip install --upgrade pip setuptools build twine openai
           python -m pip install -e .[dev]
       - name: Test with pytest
         run: |
diff --git a/py/autoevals/llm.py b/py/autoevals/llm.py
index 9e97284..c85a8a3 100644
--- a/py/autoevals/llm.py
+++ b/py/autoevals/llm.py
@@ -5,7 +5,6 @@
 from typing import List, Optional
 
 import chevron
-import openai
 import yaml
 
 from .base import Score, Scorer
@@ -129,16 +128,21 @@ def _render_messages(self, **kwargs):
         ]
 
     def _request_args(self, output, expected, **kwargs):
-        return dict(
-            Completion=openai.ChatCompletion,
+        ret = dict(
             model=self.model,
-            engine=self.engine,
             messages=self._render_messages(output=output, expected=expected, **kwargs),
             functions=self.classification_functions,
             function_call={"name": "select_choice"},
             **self.extra_args,
         )
 
+        if self.engine is not None:
+            # This parameter has been deprecated (https://help.openai.com/en/articles/6283125-what-happened-to-engines)
+            # and is unsupported in OpenAI v1, so only set it if the user has specified it
+            ret["engine"] = self.engine
+
+        return ret
+
     def _postprocess_response(self, resp):
         if len(resp["choices"]) > 0:
             return self._process_response(resp["choices"][0]["message"])
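A note on the llm.py change: `_request_args` no longer passes a `Completion` class through to the cache layer (that layer now constructs its own completion function), and the deprecated `engine` key is attached only when a caller explicitly set one, since OpenAI v1 rejects it. A minimal, self-contained sketch of that conditional-argument pattern (the `request_args` helper and the values here are hypothetical, not part of the patch):

    # Hypothetical sketch of the conditional "engine" handling above.
    def request_args(model, messages, engine=None, **extra):
        ret = dict(model=model, messages=messages, **extra)
        if engine is not None:
            # Legacy v0-only parameter; unsupported by OpenAI v1 clients.
            ret["engine"] = engine
        return ret

    assert "engine" not in request_args("gpt-3.5-turbo", [])
    assert request_args("gpt-3.5-turbo", [], engine="davinci")["engine"] == "davinci"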
diff --git a/py/autoevals/oai.py b/py/autoevals/oai.py
index bc91e54..a2789ef 100644
--- a/py/autoevals/oai.py
+++ b/py/autoevals/oai.py
@@ -2,11 +2,13 @@
 import json
 import os
 import sqlite3
+import sys
+import textwrap
 import threading
 import time
 from pathlib import Path
 
-from .util import current_span, traced
+from .util import current_span
 
 _CACHE_DIR = None
 _CONN = None
@@ -30,52 +32,103 @@ def open_cache():
     return _CONN
 
 
-def log_openai_request(input_args, response, **kwargs):
-    span = current_span()
-    if not span:
-        return
+CACHE_LOCK = threading.Lock()
 
-    input = input_args.pop("messages")
-    span.log(
-        metrics={
-            "tokens": response["usage"]["total_tokens"],
-            "prompt_tokens": response["usage"]["prompt_tokens"],
-            "completion_tokens": response["usage"]["completion_tokens"],
-        },
-        metadata={**input_args, **kwargs},
-        input=input,
-        output=response["choices"][0],
-    )
 
+def prepare_openai_complete(is_async=False, api_key=None):
+    try:
+        import openai
+    except Exception as e:
+        print(
+            textwrap.dedent(
+                f"""\
+            Unable to import openai: {e}
+
+            Please install it, e.g. with
+
+              pip install 'openai'
+            """
+            ),
+            file=sys.stderr,
+        )
+        raise
+
+    openai_obj = openai
+    is_v1 = False
+    if hasattr(openai, "OpenAI"):
+        # This is the new v1 API
+        is_v1 = True
+        if is_async:
+            openai_obj = openai.AsyncOpenAI(api_key=api_key)
+        else:
+            openai_obj = openai.OpenAI(api_key=api_key)
+
+    try:
+        from braintrust.oai import wrap_openai
+
+        openai_obj = wrap_openai(openai_obj)
+    except ImportError:
+        pass
+
+    complete_fn = None
+    rate_limit_error = None
+    if is_v1:
+        rate_limit_error = openai.RateLimitError
+        complete_fn = openai_obj.chat.completions.create
+    else:
+        rate_limit_error = openai.error.RateLimitError
+        if is_async:
+            complete_fn = openai_obj.ChatCompletion.acreate
+        else:
+            complete_fn = openai_obj.ChatCompletion.create
 
-CACHE_LOCK = threading.Lock()
+    return complete_fn, rate_limit_error
 
 
-@traced(name="OpenAI Completion")
-def run_cached_request(Completion=None, **kwargs):
-    # OpenAI is very slow to import, so we only do it if we need it
-    import openai
+def post_process_response(resp):
+    # This normalizes against craziness in OpenAI v0 vs. v1
+    if hasattr(resp, "to_dict"):
+        # v0
+        return resp.to_dict()
+    else:
+        # v1
+        return resp.dict()
+
+
+def log_cached_response(params, resp):
+    with current_span().start_span(name="OpenAI Completion") as span:
+        messages = params.pop("messages", None)
+        span.log(
+            metrics={
+                "tokens": resp["usage"]["total_tokens"],
+                "prompt_tokens": resp["usage"]["prompt_tokens"],
+                "completion_tokens": resp["usage"]["completion_tokens"],
+            },
+            input=messages,
+            output=resp["choices"],
+        )
 
-    if Completion is None:
-        Completion = openai.Completion
+
+def run_cached_request(api_key=None, **kwargs):
+    # OpenAI is very slow to import, so we only do it if we need it
+    complete, RateLimitError = prepare_openai_complete(is_async=False, api_key=api_key)
 
     param_key = json.dumps(kwargs)
     conn = open_cache()
     with CACHE_LOCK:
         cursor = conn.cursor()
         resp = cursor.execute("""SELECT response FROM "cache" WHERE params=?""", [param_key]).fetchone()
-    cached = False
     retries = 0
     if resp:
-        cached = True
         resp = json.loads(resp[0])
+        log_cached_response(kwargs, resp)
     else:
         sleep_time = 0.1
         while retries < 20:
             try:
-                resp = Completion.create(**kwargs).to_dict()
+                resp = post_process_response(complete(**kwargs))
                 break
-            except openai.error.RateLimitError:
+            except RateLimitError:
                 sleep_time *= 1.5
                 time.sleep(sleep_time)
                 retries += 1
@@ -85,36 +138,29 @@ def run_cached_request(Completion=None, **kwargs):
         cursor.execute("""INSERT INTO "cache" VALUES (?, ?)""", [param_key, json.dumps(resp)])
         conn.commit()
 
-    log_openai_request(kwargs, resp, cached=cached)
-
     return resp
 
 
-@traced(name="OpenAI Completion")
-async def arun_cached_request(Completion=None, **kwargs):
-    # OpenAI is very slow to import, so we only do it if we need it
-    import openai
-
-    if Completion is None:
-        Completion = openai.Completion
+async def arun_cached_request(api_key=None, **kwargs):
+    complete, RateLimitError = prepare_openai_complete(is_async=True, api_key=api_key)
 
     param_key = json.dumps(kwargs)
     conn = open_cache()
     with CACHE_LOCK:
         cursor = conn.cursor()
         resp = cursor.execute("""SELECT response FROM "cache" WHERE params=?""", [param_key]).fetchone()
-    cached = False
     retries = 0
     if resp:
         resp = json.loads(resp[0])
-        cached = True
+        log_cached_response(kwargs, resp)
     else:
         sleep_time = 0.1
         while retries < 100:
             try:
-                resp = (await Completion.acreate(**kwargs)).to_dict()
+                resp = post_process_response(await complete(**kwargs))
                 break
-            except openai.error.RateLimitError:
+            except RateLimitError:
+                # Just assume it's a rate limit error
                 sleep_time *= 1.5
                 await asyncio.sleep(sleep_time)
                 retries += 1
@@ -124,6 +170,4 @@ async def arun_cached_request(Completion=None, **kwargs):
         cursor.execute("""INSERT INTO "cache" VALUES (?, ?)""", [param_key, json.dumps(resp)])
         conn.commit()
 
-    log_openai_request(kwargs, resp, cached=cached, retries=retries)
-
     return resp
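The core of the oai.py rewrite is `prepare_openai_complete`, which probes the installed SDK rather than assuming a pinned version: the presence of `openai.OpenAI` marks the v1 client API, and the completion function and rate-limit exception are picked to match, so the retry loops in `run_cached_request`/`arun_cached_request` stay version-agnostic. A standalone sketch of that detection idiom, drawn from the hunk above (synchronous case only):

    import openai

    if hasattr(openai, "OpenAI"):  # v1: completions live on a client object
        client = openai.OpenAI()
        complete = client.chat.completions.create
        RateLimitError = openai.RateLimitError
    else:  # v0: module-level ChatCompletion API
        complete = openai.ChatCompletion.create
        RateLimitError = openai.error.RateLimitError

    # Either branch yields a callable with the same keyword interface:
    # complete(model=..., messages=[...])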
diff --git a/py/autoevals/util.py b/py/autoevals/util.py
index 20a51a3..018479d 100644
--- a/py/autoevals/util.py
+++ b/py/autoevals/util.py
@@ -1,5 +1,6 @@
 import dataclasses
 import json
+import time
 
 
 class SerializableDataClass:
@@ -12,21 +13,40 @@ def as_json(self, **kwargs):
         return json.dumps(self.as_dict(), **kwargs)
 
 
-class NoOpSpan:
-    def log(self, **kwargs):
+# DEVNOTE: This is copied from braintrust-sdk/py/src/braintrust/logger.py
+class _NoopSpan:
+    def __init__(self, *args, **kwargs):
         pass
 
-    def start_span(self, *args, **kwargs):
-        return self
+    @property
+    def id(self):
+        return ""
+
+    @property
+    def span_id(self):
+        return ""
+
+    @property
+    def root_span_id(self):
+        return ""
 
-    def end(self, *args, **kwargs):
+    def log(self, **event):
         pass
 
+    def start_span(self, name, span_attributes={}, start_time=None, set_current=None, **event):
+        return self
+
+    def end(self, end_time=None):
+        return end_time or time.time()
+
+    def close(self, end_time=None):
+        return self.end(end_time)
+
     def __enter__(self):
-        pass
+        return self
 
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        pass
+    def __exit__(self, type, value, callback):
+        del type, value, callback
 
 
 def current_span():
@@ -35,7 +55,7 @@ def current_span():
 
         return _get_current_span()
     except ImportError as e:
-        return NoOpSpan()
+        return _NoopSpan()
 
 
 def traced(*span_args, **span_kwargs):
diff --git a/setup.py b/setup.py
index 6143ce4..2200f29 100644
--- a/setup.py
+++ b/setup.py
@@ -11,7 +11,7 @@
 with open(os.path.join(dir_name, "README.md"), "r", encoding="utf-8") as f:
     long_description = f.read()
 
-install_requires = ["chevron", "openai==0.28.1", "levenshtein", "pyyaml"]
+install_requires = ["chevron", "levenshtein", "pyyaml"]
 
 extras_require = {
     "dev": [
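Taken together with the fuller `_NoopSpan` in util.py (which mirrors the braintrust SDK span interface, so `current_span().start_span(...)` works as a context manager even without braintrust installed), dropping the `openai==0.28.1` pin from `install_requires` lets autoevals run against either SDK generation. A sketch of typical downstream usage after this change (assuming the scorer API shown in the autoevals README; requires a real `OPENAI_API_KEY`):

    from autoevals.llm import Factuality

    evaluator = Factuality()  # the installed openai (v0 or v1) is resolved lazily
    result = evaluator(
        output="People's Republic of China",
        expected="China",
        input="Which country has the highest population?",
    )
    print(result.score)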