diff --git a/.gitignore b/.gitignore index 92ced80..8e7e39b 100644 --- a/.gitignore +++ b/.gitignore @@ -174,4 +174,5 @@ poetry.toml pyrightconfig.json # End of https://www.toptal.com/developers/gitignore/api/python -n \ No newline at end of file +n +.envrc \ No newline at end of file diff --git a/dialog_lib/agents/abstract.py b/dialog_lib/agents/abstract.py index b91ce78..e5fda2d 100644 --- a/dialog_lib/agents/abstract.py +++ b/dialog_lib/agents/abstract.py @@ -12,6 +12,7 @@ from langchain_core.runnables.history import RunnableWithMessageHistory from langchain.chains.conversation.memory import ConversationBufferMemory +from dialog_lib.db import get_session from dialog_lib.db.memory import CustomPostgresChatMessageHistory, get_memory_instance from dialog_lib.embeddings.retrievers import DialogRetriever @@ -24,7 +25,7 @@ def __init__( parent_session_id=None, dataset=None, llm_api_key=None, - dbsession=None, + dbsession=get_session(), ): """ :param config: Configuration dictionary diff --git a/dialog_lib/db/__init__.py b/dialog_lib/db/__init__.py index 5a9fdad..2293caa 100644 --- a/dialog_lib/db/__init__.py +++ b/dialog_lib/db/__init__.py @@ -4,3 +4,4 @@ add_user_message_to_message_history, get_messages, ) +from .session import get_session \ No newline at end of file diff --git a/dialog_lib/db/memory.py b/dialog_lib/db/memory.py index f1422cc..a12d911 100644 --- a/dialog_lib/db/memory.py +++ b/dialog_lib/db/memory.py @@ -1,5 +1,5 @@ import psycopg - +from .session import get_session from langchain_postgres import PostgresChatMessageHistory from langchain.schema.messages import BaseMessage, _message_to_dict @@ -15,7 +15,7 @@ def __init__( self, *args, parent_session_id=None, - dbsession=None, + dbsession=get_session(), chats_model=Chat, chat_messages_model=ChatMessages, ssl_mode=None, @@ -67,7 +67,7 @@ def add_message(self, message: BaseMessage) -> None: def generate_memory_instance( session_id, parent_session_id=None, - dbsession=None, + dbsession=get_session(), 
database_url=None, chats_model=Chat, chat_messages_model=ChatMessages, @@ -88,7 +88,7 @@ def generate_memory_instance( def add_user_message_to_message_history( - session_id, message, memory=None, dbsession=None, database_url=None + session_id, message, memory=None, dbsession=get_session(), database_url=None ): """ Add a user message to the message history and returns the updated @@ -103,7 +103,7 @@ def add_user_message_to_message_history( return memory -def get_messages(session_id, dbsession=None, database_url=None): +def get_messages(session_id, dbsession=get_session(), database_url=None): """ Get all messages for a given session_id """ diff --git a/dialog_lib/db/session.py b/dialog_lib/db/session.py new file mode 100644 index 0000000..0dc6d34 --- /dev/null +++ b/dialog_lib/db/session.py @@ -0,0 +1,16 @@ +import os + +import sqlalchemy as sa +from sqlalchemy.orm import Session + +def get_session(): + """Yield a Session bound to the DATABASE_URL engine and close it afterwards. + + This is a generator function: calling it returns a generator, not a Session, so the caller must advance it to obtain the Session. The session is closed in a finally clause so it is released even if the consumer raises or closes the generator early. + """ + engine = sa.create_engine(os.environ.get("DATABASE_URL")) + session = Session(engine) + try: + yield session + finally: + session.close() \ No newline at end of file diff --git a/dialog_lib/db/utils.py b/dialog_lib/db/utils.py index eaf4bc3..b7ba645 100644 --- a/dialog_lib/db/utils.py +++ b/dialog_lib/db/utils.py @@ -1,9 +1,9 @@ import uuid - +from .session import get_session from .models import Chat -def create_chat_session(identifier=None, dbsession=None, model=Chat): +def create_chat_session(identifier=None, dbsession=get_session(), model=Chat): if identifier is None: identifier = uuid.uuid4().hex diff --git a/dialog_lib/loaders/csv.py b/dialog_lib/loaders/csv.py index 7ea223a..706685b 100644 --- a/dialog_lib/loaders/csv.py +++ b/dialog_lib/loaders/csv.py @@ -1,3 +1,4 @@ +from dialog_lib.db import get_session from dialog_lib.db.models import CompanyContent from dialog_lib.embeddings.generate import generate_embedding @@ -6,9 +7,10 @@ def load_csv( - file_path, dbsession, embeddings_model_instance=None, + file_path, dbsession=get_session(), embeddings_model_instance=None,
embedding_llm_model=None, embedding_llm_api_key=None, company_id=None ): + loader = CSVLoader(file_path=file_path) contents = loader.load() diff --git a/dialog_lib/loaders/web.py b/dialog_lib/loaders/web.py index a05c89d..d1f825c 100644 --- a/dialog_lib/loaders/web.py +++ b/dialog_lib/loaders/web.py @@ -1,10 +1,10 @@ from dialog_lib.db.models import CompanyContent from dialog_lib.embeddings.generate import generate_embedding - +from dialog_lib.db import get_session from langchain_community.document_loaders import WebBaseLoader -def load_webpage(url, dbsession, embeddings_model_instance, company_id=None): +def load_webpage(url, embeddings_model_instance, dbsession=get_session(), company_id=None): loader = WebBaseLoader(url) contents = loader.load() diff --git a/dialog_lib/manage.py b/dialog_lib/manage.py index 0be25d2..898e7ef 100644 --- a/dialog_lib/manage.py +++ b/dialog_lib/manage.py @@ -76,7 +76,6 @@ def anthropic(model, temperature, llm_api_key, prompt, debug): @click.option("--llm-api-key", default=get_llm_key(), help="The LLM API key", required=True) @click.option("--file", help="The CSV file to load the data from", required=True) def load_csv(database_url, llm_api_key, file): - breakpoint() engine = create_engine(database_url) dbsession = Session(engine.connect()) csv_loader( @@ -86,6 +85,7 @@ def load_csv(database_url, llm_api_key, file): embedding_llm_api_key=llm_api_key ) click.echo("## Loaded the CSV file to the database") + dbsession.close() def main(): diff --git a/dialog_lib/tests/agents/test_abstract_agents.py b/dialog_lib/tests/agents/test_abstract_agents.py index cba828e..16446c7 100644 --- a/dialog_lib/tests/agents/test_abstract_agents.py +++ b/dialog_lib/tests/agents/test_abstract_agents.py @@ -20,7 +20,6 @@ def test_abstract_agent_with_valid_config(): assert agent.dataset is None assert agent.llm_api_key is None assert agent.parent_session_id is None - assert agent.dbsession is None def test_abstract_agent_get_prompt(): config = { diff --git 
a/dialog_lib/tests/loaders/test_web_loader.py b/dialog_lib/tests/loaders/test_web_loader.py index f7b7006..ce271dc 100644 --- a/dialog_lib/tests/loaders/test_web_loader.py +++ b/dialog_lib/tests/loaders/test_web_loader.py @@ -9,7 +9,7 @@ def test_load_web_content(mock_aioresponse, db_session, mocker): mocker.patch('dialog_lib.loaders.web.generate_embedding', return_value=[0] * 1536) mock_aioresponse.get('http://example.com', body='Hello, world!') - content = load_webpage('http://example.com', db_session, None, 1) + content = load_webpage('http://example.com', None, db_session, 1) assert content.question == "Example Domain" assert content.embedding == [0] * 1536