From 2fb5d6a748a0293fafbbbe5ed2efadfe8a4f1a42 Mon Sep 17 00:00:00 2001 From: Ankur Goyal Date: Fri, 3 Nov 2023 13:44:13 -0700 Subject: [PATCH] Disable threading checks and synchronize access with sync/async locks --- py/autoevals/oai.py | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/py/autoevals/oai.py b/py/autoevals/oai.py index 16ff671..b0381f0 100644 --- a/py/autoevals/oai.py +++ b/py/autoevals/oai.py @@ -2,6 +2,7 @@ import json import os import sqlite3 +import threading import time from pathlib import Path @@ -24,7 +25,7 @@ def open_cache(): if _CONN is None: oai_cache_path = Path(_CACHE_DIR) / "oai.sqlite" os.makedirs(_CACHE_DIR, exist_ok=True) - _CONN = sqlite3.connect(oai_cache_path) + _CONN = sqlite3.connect(oai_cache_path, check_same_thread=False) _CONN.execute("CREATE TABLE IF NOT EXISTS cache (params text, response text)") return _CONN @@ -47,6 +48,9 @@ def log_openai_request(input_args, response, **kwargs): ) +CACHE_LOCK = threading.Lock() + + @traced(name="OpenAI Completion") def run_cached_request(Completion=None, **kwargs): # OpenAI is very slow to import, so we only do it if we need it @@ -57,8 +61,9 @@ def run_cached_request(Completion=None, **kwargs): param_key = json.dumps(kwargs) conn = open_cache() - cursor = conn.cursor() - resp = cursor.execute("""SELECT response FROM "cache" WHERE params=?""", [param_key]).fetchone() + with CACHE_LOCK: + cursor = conn.cursor() + resp = cursor.execute("""SELECT response FROM "cache" WHERE params=?""", [param_key]).fetchone() cached = False retries = 0 if resp: @@ -75,14 +80,19 @@ def run_cached_request(Completion=None, **kwargs): time.sleep(sleep_time) retries += 1 - cursor.execute("""INSERT INTO "cache" VALUES (?, ?)""", [param_key, json.dumps(resp)]) - conn.commit() + with CACHE_LOCK: + cursor = conn.cursor() + cursor.execute("""INSERT INTO "cache" VALUES (?, ?)""", [param_key, json.dumps(resp)]) + conn.commit() log_openai_request(kwargs, resp, cached=cached) return resp +ACACHE_LOCK = asyncio.Lock() + + @traced(name="OpenAI Completion") async def arun_cached_request(Completion=None, **kwargs): # OpenAI is very slow to import, so we only do it if we need it @@ -93,8 +103,9 @@ async def arun_cached_request(Completion=None, **kwargs): param_key = json.dumps(kwargs) conn = open_cache() - cursor = conn.cursor() - resp = cursor.execute("""SELECT response FROM "cache" WHERE params=?""", [param_key]).fetchone() + with ACACHE_LOCK: + cursor = conn.cursor() + resp = cursor.execute("""SELECT response FROM "cache" WHERE params=?""", [param_key]).fetchone() cached = False retries = 0 if resp: @@ -111,8 +122,10 @@ async def arun_cached_request(Completion=None, **kwargs): await asyncio.sleep(sleep_time) retries += 1 - cursor.execute("""INSERT INTO "cache" VALUES (?, ?)""", [param_key, json.dumps(resp)]) - conn.commit() + with ACACHE_LOCK: + cursor = conn.cursor() + cursor.execute("""INSERT INTO "cache" VALUES (?, ?)""", [param_key, json.dumps(resp)]) + conn.commit() log_openai_request(kwargs, resp, cached=cached, retries=retries)