Set the braintrust proxy as the api url to facilitate caching (#33)
ankrgyl authored Nov 28, 2023
1 parent 5c836de commit 975fce7
Showing 8 changed files with 59 additions and 109 deletions.
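In short, the commit drops the local sqlite response cache and points both the JS and Python OpenAI clients at the Braintrust proxy (https://braintrustproxy.com/v1) by default, adding `openAiBaseUrl` / `base_url` parameters that are threaded through to the client setup. Below is a minimal Python usage sketch of the new option; it assumes the `Factuality` spec-file classifier is available in `autoevals.llm` and the usual `output`/`expected`/`input` call signature, and the key and URL are illustrative placeholders.

```python
from autoevals.llm import Factuality

# Default behavior after this commit: requests are routed through the
# Braintrust proxy (https://braintrustproxy.com/v1), which handles caching.
cached_eval = Factuality()

# base_url (new in this commit) is threaded through to the OpenAI client
# setup as the intended override for the proxy default.
direct_eval = Factuality(
    api_key="sk-...",                      # illustrative placeholder
    base_url="https://api.openai.com/v1",  # illustrative override
)

result = direct_eval(
    output="Paris is the capital of France.",
    expected="Paris",
    input="What is the capital of France?",
)
print(result.score)
```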
6 changes: 4 additions & 2 deletions .github/workflows/js.yaml
@@ -1,5 +1,8 @@
name: js

env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}

on:
pull_request:
push:
@@ -21,6 +24,5 @@ jobs:
node-version: ${{ matrix.node-version }}
cache: "npm"
- run: npm install
# @TODO: Re-enable this once cache is integrated with JS tests.
# - run: npm run test
- run: npm run test
- run: npm run build
6 changes: 3 additions & 3 deletions .github/workflows/python.yaml
@@ -1,14 +1,14 @@
# Modified from https://github.com/actions/starter-workflows/blob/main/ci/python-app.yml
name: python

env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}

on:
pull_request:
push:
branches: [main]

env:
OPENAI_API_KEY: sk-dummy

jobs:
build:
runs-on: ubuntu-latest
Binary file removed .testcache/oai.sqlite
3 changes: 3 additions & 0 deletions js/llm.ts
@@ -24,6 +24,7 @@ interface LLMArgs {
temperature?: number;
openAiApiKey?: string;
openAiOrganizationId?: string;
openAiBaseUrl?: string;
}

const PLAIN_RESPONSE_SCHEMA = {
@@ -80,6 +81,7 @@ export async function OpenAIClassifier<RenderArgs, Output>(
expected,
openAiApiKey,
openAiOrganizationId,
openAiBaseUrl,
...remaining
} = args;

@@ -138,6 +140,7 @@ export async function OpenAIClassifier<RenderArgs, Output>(
cache,
openAiApiKey,
openAiOrganizationId,
openAiBaseUrl,
}
);

6 changes: 5 additions & 1 deletion js/oai.ts
@@ -25,13 +25,16 @@ export interface ChatCache {
export interface OpenAIAuth {
openAiApiKey?: string;
openAiOrganizationId?: string;
openAiBaseUrl?: string;
}

const PROXY_URL = "https://braintrustproxy.com/v1";

export async function cachedChatCompletion(
params: CachedLLMParams,
options: { cache?: ChatCache } & OpenAIAuth
): Promise<ChatCompletion> {
const { cache, openAiApiKey, openAiOrganizationId } = options;
const { cache, openAiApiKey, openAiOrganizationId, openAiBaseUrl } = options;

return await currentSpan().traced("OpenAI Completion", async (span: any) => {
let cached = false;
@@ -42,6 +45,7 @@ export async function cachedChatCompletion(
const openai = new OpenAI({
apiKey: openAiApiKey || Env.OPENAI_API_KEY,
organization: openAiOrganizationId,
baseURL: openAiBaseUrl || PROXY_URL,
});

if (openai === null) {
11 changes: 10 additions & 1 deletion py/autoevals/llm.py
@@ -79,6 +79,7 @@ def __init__(
temperature=None,
engine=None,
api_key=None,
base_url=None,
):
self.name = name
self.model = model
@@ -92,6 +93,8 @@ def __init__(
self.extra_args["max_tokens"] = max(max_tokens, 5)
if api_key:
self.extra_args["api_key"] = api_key
if base_url:
self.extra_args["base_url"] = base_url

self.render_args = {}
if render_args:
@@ -199,6 +202,7 @@ def __init__(
temperature=0,
engine=None,
api_key=None,
base_url=None,
):
choice_strings = list(choice_scores.keys())

@@ -220,6 +224,7 @@ def __init__(
temperature=temperature,
engine=engine,
api_key=api_key,
base_url=base_url,
render_args={"__choices": choice_strings},
)

@@ -235,7 +240,9 @@ def from_spec_file(cls, name: str, path: str, **kwargs):


class SpecFileClassifier(LLMClassifier):
def __new__(cls, model=None, engine=None, use_cot=None, max_tokens=None, temperature=None, api_key=None):
def __new__(
cls, model=None, engine=None, use_cot=None, max_tokens=None, temperature=None, api_key=None, base_url=None
):
kwargs = {}
if model is not None:
kwargs["model"] = model
@@ -249,6 +256,8 @@ def __new__(cls, model=None, engine=None, use_cot=None, max_tokens=None, tempera
kwargs["temperature"] = temperature
if api_key is not None:
kwargs["api_key"] = api_key
if base_url is not None:
kwargs["base_url"] = base_url

# convert FooBar to foo_bar
template_name = re.sub(r"(?<!^)(?=[A-Z])", "_", cls.__name__).lower()
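A sketch of the same option at the classifier layer. The classifier name, prompt, and choices below are invented for illustration; it assumes the `name` / `prompt_template` / `choice_scores` constructor arguments and that `base_url` rides along in `extra_args` exactly like the existing `api_key` argument.

```python
from autoevals.llm import LLMClassifier

# base_url (new in this commit) is stored in extra_args alongside api_key and
# forwarded to the request helpers, as an override for the proxy default.
grader = LLMClassifier(
    name="Politeness",                                       # illustrative
    prompt_template="Is the following response polite?\n\n{{output}}",
    choice_scores={"Yes": 1, "No": 0},
    use_cot=False,
    base_url="https://api.openai.com/v1",                    # optional override
)

print(grader(output="Thanks so much for your help!").score)
```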
127 changes: 33 additions & 94 deletions py/autoevals/oai.py
@@ -1,41 +1,18 @@
import asyncio
import json
import os
import sqlite3
import sys
import textwrap
import threading
import time
from pathlib import Path

from .util import current_span

_CACHE_DIR = None
_CONN = None
PROXY_URL = "https://braintrustproxy.com/v1"


def set_cache_dir(path):
global _CACHE_DIR
_CACHE_DIR = path
def prepare_openai_complete(is_async=False, api_key=None, base_url=None):
if base_url is None:
base_url = PROXY_URL


def open_cache():
global _CACHE_DIR, _CONN
if _CACHE_DIR is None:
_CACHE_DIR = Path.home() / ".cache" / "braintrust"

if _CONN is None:
oai_cache_path = Path(_CACHE_DIR) / "oai.sqlite"
os.makedirs(_CACHE_DIR, exist_ok=True)
_CONN = sqlite3.connect(oai_cache_path, check_same_thread=False)
_CONN.execute("CREATE TABLE IF NOT EXISTS cache (params text, response text)")
return _CONN


CACHE_LOCK = threading.Lock()


def prepare_openai_complete(is_async=False, api_key=None):
try:
import openai
except Exception as e:
@@ -59,9 +36,13 @@ def prepare_openai_complete(is_async=False, api_key=None):
# This is the new v1 API
is_v1 = True
if is_async:
openai_obj = openai.AsyncOpenAI(api_key=api_key)
openai_obj = openai.AsyncOpenAI(api_key=api_key, base_url=PROXY_URL)
else:
openai_obj = openai.OpenAI(api_key=api_key)
openai_obj = openai.OpenAI(api_key=api_key, base_url=PROXY_URL)
else:
if api_key:
openai.api_key = api_key
openai.api_base = PROXY_URL

try:
from braintrust.oai import wrap_openai
@@ -95,79 +76,37 @@ def post_process_response(resp):
return resp.dict()


def log_cached_response(params, resp):
with current_span().start_span(name="OpenAI Completion") as span:
messages = params.pop("messages", None)
span.log(
metrics={
"tokens": resp["usage"]["total_tokens"],
"prompt_tokens": resp["usage"]["prompt_tokens"],
"completion_tokens": resp["usage"]["completion_tokens"],
},
input=messages,
output=resp["choices"],
)


def run_cached_request(api_key=None, **kwargs):
def run_cached_request(api_key=None, base_url=None, **kwargs):
# OpenAI is very slow to import, so we only do it if we need it
complete, RateLimitError = prepare_openai_complete(is_async=False, api_key=api_key)
complete, RateLimitError = prepare_openai_complete(is_async=False, api_key=api_key, base_url=base_url)

param_key = json.dumps(kwargs)
conn = open_cache()
with CACHE_LOCK:
cursor = conn.cursor()
resp = cursor.execute("""SELECT response FROM "cache" WHERE params=?""", [param_key]).fetchone()
retries = 0
if resp:
resp = json.loads(resp[0])
log_cached_response(kwargs, resp)
else:
sleep_time = 0.1
while retries < 20:
try:
resp = post_process_response(complete(**kwargs))
break
except RateLimitError:
sleep_time *= 1.5
time.sleep(sleep_time)
retries += 1

with CACHE_LOCK:
cursor = conn.cursor()
cursor.execute("""INSERT INTO "cache" VALUES (?, ?)""", [param_key, json.dumps(resp)])
conn.commit()
sleep_time = 0.1
while retries < 100:
try:
resp = post_process_response(complete(**kwargs))
break
except RateLimitError:
sleep_time *= 1.5
time.sleep(sleep_time)
retries += 1

return resp


async def arun_cached_request(api_key=None, **kwargs):
complete, RateLimitError = prepare_openai_complete(is_async=True, api_key=api_key)
async def arun_cached_request(api_key=None, base_url=None, **kwargs):
complete, RateLimitError = prepare_openai_complete(is_async=True, api_key=api_key, base_url=base_url)

param_key = json.dumps(kwargs)
conn = open_cache()
with CACHE_LOCK:
cursor = conn.cursor()
resp = cursor.execute("""SELECT response FROM "cache" WHERE params=?""", [param_key]).fetchone()
retries = 0
if resp:
resp = json.loads(resp[0])
log_cached_response(kwargs, resp)
else:
sleep_time = 0.1
while retries < 100:
try:
resp = post_process_response(await complete(**kwargs))
break
except RateLimitError:
# Just assume it's a rate limit error
sleep_time *= 1.5
await asyncio.sleep(sleep_time)
retries += 1

with CACHE_LOCK:
cursor = conn.cursor()
cursor.execute("""INSERT INTO "cache" VALUES (?, ?)""", [param_key, json.dumps(resp)])
conn.commit()
sleep_time = 0.1
while retries < 100:
try:
resp = post_process_response(await complete(**kwargs))
break
except RateLimitError:
# Just assume it's a rate limit error
sleep_time *= 1.5
await asyncio.sleep(sleep_time)
retries += 1

return resp
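And a sketch at the lowest level, using the reworked request helper directly. The model, key, and message are illustrative placeholders; with no `base_url`, the call falls back to `PROXY_URL` and is retried with exponential backoff on rate limit errors.

```python
from autoevals.oai import run_cached_request

# Without base_url, prepare_openai_complete falls back to PROXY_URL, so the
# Braintrust proxy now provides the caching the sqlite layer used to handle.
resp = run_cached_request(
    api_key="sk-...",  # illustrative placeholder
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Reply with the single word: hi"}],
    max_tokens=5,
)

print(resp["choices"][0]["message"]["content"])
```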
9 changes: 1 addition & 8 deletions py/autoevals/test_llm.py
@@ -3,15 +3,8 @@

import chevron

from autoevals.llm import build_classification_functions
from autoevals.oai import set_cache_dir

# By default, we use the user's tmp cache directory (e.g. in the Library/Caches dir on macOS)
# However, we'd like to cache (and commit) the results of our tests, so we monkey patch the library
# to use a cache directory in the project root.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
set_cache_dir(os.path.join(_SCRIPT_DIR, "../../.testcache"))
from autoevals.llm import *
from autoevals.llm import build_classification_functions


def test_template_html():
