Skip to content

Commit

Permalink
Fix
Browse files: browse the repository at this point in the history
  • Loading branch information
ankrgyl committed Nov 28, 2023
1 parent 25da796 commit bdf4794
Showing 1 changed file with 31 additions and 50 deletions.
81 changes: 31 additions & 50 deletions py/autoevals/oai.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -36,7 +36,10 @@ def open_cache():
PROXY_URL = "https://braintrustproxy.com/v1"


def prepare_openai_complete(is_async=False, api_key=None):
def prepare_openai_complete(is_async=False, api_key=None, api_url=None):
if api_url is None:
api_url = PROXY_URL

try:
import openai
except Exception as e:
Expand All @@ -60,9 +63,13 @@ def prepare_openai_complete(is_async=False, api_key=None):
# This is the new v1 API
is_v1 = True
if is_async:
openai_obj = openai.AsyncOpenAI(api_key=api_key, api_url=PROXY_URL)
openai_obj = openai.AsyncOpenAI(api_key=api_key, base_url=PROXY_URL)
else:
openai_obj = openai.OpenAI(api_key=api_key, api_url=PROXY_URL)
openai_obj = openai.OpenAI(api_key=api_key, base_url=PROXY_URL)
else:
if api_key:
openai.api_key = api_key
openai.api_base = PROXY_URL

try:
from braintrust.oai import wrap_openai
Expand Down Expand Up @@ -113,62 +120,36 @@ def log_cached_response(params, resp):
def run_cached_request(api_key=None, **kwargs):
# OpenAI is very slow to import, so we only do it if we need it
complete, RateLimitError = prepare_openai_complete(is_async=False, api_key=api_key)
print(kwargs)

param_key = json.dumps(kwargs)
conn = open_cache()
with CACHE_LOCK:
cursor = conn.cursor()
resp = cursor.execute("""SELECT response FROM "cache" WHERE params=?""", [param_key]).fetchone()
retries = 0
if resp:
resp = json.loads(resp[0])
log_cached_response(kwargs, resp)
else:
sleep_time = 0.1
while retries < 20:
try:
resp = post_process_response(complete(**kwargs))
break
except RateLimitError:
sleep_time *= 1.5
time.sleep(sleep_time)
retries += 1

with CACHE_LOCK:
cursor = conn.cursor()
cursor.execute("""INSERT INTO "cache" VALUES (?, ?)""", [param_key, json.dumps(resp)])
conn.commit()
sleep_time = 0.1
while retries < 100:
try:
resp = post_process_response(complete(**kwargs))
break
except RateLimitError:
sleep_time *= 1.5
time.sleep(sleep_time)
retries += 1

return resp


async def arun_cached_request(api_key=None, **kwargs):
complete, RateLimitError = prepare_openai_complete(is_async=True, api_key=api_key)
print(kwargs)

param_key = json.dumps(kwargs)
conn = open_cache()
with CACHE_LOCK:
cursor = conn.cursor()
resp = cursor.execute("""SELECT response FROM "cache" WHERE params=?""", [param_key]).fetchone()
retries = 0
if resp:
resp = json.loads(resp[0])
log_cached_response(kwargs, resp)
else:
sleep_time = 0.1
while retries < 100:
try:
resp = post_process_response(await complete(**kwargs))
break
except RateLimitError:
# Just assume it's a rate limit error
sleep_time *= 1.5
await asyncio.sleep(sleep_time)
retries += 1

with CACHE_LOCK:
cursor = conn.cursor()
cursor.execute("""INSERT INTO "cache" VALUES (?, ?)""", [param_key, json.dumps(resp)])
conn.commit()
sleep_time = 0.1
while retries < 100:
try:
resp = post_process_response(await complete(**kwargs))
break
except RateLimitError:
# Just assume it's a rate limit error
sleep_time *= 1.5
await asyncio.sleep(sleep_time)
retries += 1

return resp

0 comments on commit bdf4794

Please sign in to comment.