Skip to content

Commit

Permalink
Disable threading checks and synchronize access with sync/async locks
Browse files Browse the repository at this point in the history
  • Loading branch information
ankrgyl committed Nov 3, 2023
1 parent b80c9bd commit 2fb5d6a
Showing 1 changed file with 22 additions and 9 deletions.
31 changes: 22 additions & 9 deletions py/autoevals/oai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import os
import sqlite3
import threading
import time
from pathlib import Path

Expand All @@ -24,7 +25,7 @@ def open_cache():
if _CONN is None:
oai_cache_path = Path(_CACHE_DIR) / "oai.sqlite"
os.makedirs(_CACHE_DIR, exist_ok=True)
_CONN = sqlite3.connect(oai_cache_path)
_CONN = sqlite3.connect(oai_cache_path, check_same_thread=False)
_CONN.execute("CREATE TABLE IF NOT EXISTS cache (params text, response text)")
return _CONN

Expand All @@ -47,6 +48,9 @@ def log_openai_request(input_args, response, **kwargs):
)


CACHE_LOCK = threading.Lock()


@traced(name="OpenAI Completion")
def run_cached_request(Completion=None, **kwargs):
# OpenAI is very slow to import, so we only do it if we need it
Expand All @@ -57,8 +61,9 @@ def run_cached_request(Completion=None, **kwargs):

param_key = json.dumps(kwargs)
conn = open_cache()
cursor = conn.cursor()
resp = cursor.execute("""SELECT response FROM "cache" WHERE params=?""", [param_key]).fetchone()
with CACHE_LOCK:
cursor = conn.cursor()
resp = cursor.execute("""SELECT response FROM "cache" WHERE params=?""", [param_key]).fetchone()
cached = False
retries = 0
if resp:
Expand All @@ -75,14 +80,19 @@ def run_cached_request(Completion=None, **kwargs):
time.sleep(sleep_time)
retries += 1

cursor.execute("""INSERT INTO "cache" VALUES (?, ?)""", [param_key, json.dumps(resp)])
conn.commit()
with CACHE_LOCK:
cursor = conn.cursor()
cursor.execute("""INSERT INTO "cache" VALUES (?, ?)""", [param_key, json.dumps(resp)])
conn.commit()

log_openai_request(kwargs, resp, cached=cached)

return resp


ACACHE_LOCK = asyncio.Lock()


@traced(name="OpenAI Completion")
async def arun_cached_request(Completion=None, **kwargs):
# OpenAI is very slow to import, so we only do it if we need it
Expand All @@ -93,8 +103,9 @@ async def arun_cached_request(Completion=None, **kwargs):

param_key = json.dumps(kwargs)
conn = open_cache()
cursor = conn.cursor()
resp = cursor.execute("""SELECT response FROM "cache" WHERE params=?""", [param_key]).fetchone()
with ACACHE_LOCK:
cursor = conn.cursor()
resp = cursor.execute("""SELECT response FROM "cache" WHERE params=?""", [param_key]).fetchone()
cached = False
retries = 0
if resp:
Expand All @@ -111,8 +122,10 @@ async def arun_cached_request(Completion=None, **kwargs):
await asyncio.sleep(sleep_time)
retries += 1

cursor.execute("""INSERT INTO "cache" VALUES (?, ?)""", [param_key, json.dumps(resp)])
conn.commit()
with ACACHE_LOCK:
cursor = conn.cursor()
cursor.execute("""INSERT INTO "cache" VALUES (?, ?)""", [param_key, json.dumps(resp)])
conn.commit()

log_openai_request(kwargs, resp, cached=cached, retries=retries)

Expand Down

0 comments on commit 2fb5d6a

Please sign in to comment.