Skip to content

Commit

Permalink
Adds session injection function for better session management in stan…
Browse files Browse the repository at this point in the history
…d-alone library
  • Loading branch information
vmesel committed Jul 3, 2024
1 parent eb9e41b commit 554bc9e
Show file tree
Hide file tree
Showing 11 changed files with 29 additions and 15 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -174,4 +174,5 @@ poetry.toml
pyrightconfig.json

# End of https://www.toptal.com/developers/gitignore/api/python
n
n
.envrc
3 changes: 2 additions & 1 deletion dialog_lib/agents/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain.chains.conversation.memory import ConversationBufferMemory

from dialog_lib.db import get_session
from dialog_lib.db.memory import CustomPostgresChatMessageHistory, get_memory_instance
from dialog_lib.embeddings.retrievers import DialogRetriever

Expand All @@ -24,7 +25,7 @@ def __init__(
parent_session_id=None,
dataset=None,
llm_api_key=None,
dbsession=None,
dbsession=get_session(),
):
"""
:param config: Configuration dictionary
Expand Down
1 change: 1 addition & 0 deletions dialog_lib/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
add_user_message_to_message_history,
get_messages,
)
from .session import get_session
10 changes: 5 additions & 5 deletions dialog_lib/db/memory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import psycopg

from .session import get_session
from langchain_postgres import PostgresChatMessageHistory
from langchain.schema.messages import BaseMessage, _message_to_dict

Expand All @@ -15,7 +15,7 @@ def __init__(
self,
*args,
parent_session_id=None,
dbsession=None,
dbsession=get_session(),
chats_model=Chat,
chat_messages_model=ChatMessages,
ssl_mode=None,
Expand Down Expand Up @@ -67,7 +67,7 @@ def add_message(self, message: BaseMessage) -> None:
def generate_memory_instance(
session_id,
parent_session_id=None,
dbsession=None,
dbsession=get_session(),
database_url=None,
chats_model=Chat,
chat_messages_model=ChatMessages,
Expand All @@ -88,7 +88,7 @@ def generate_memory_instance(


def add_user_message_to_message_history(
session_id, message, memory=None, dbsession=None, database_url=None
session_id, message, memory=None, dbsession=get_session(), database_url=None
):
"""
Add a user message to the message history and returns the updated
Expand All @@ -103,7 +103,7 @@ def add_user_message_to_message_history(
return memory


def get_messages(session_id, dbsession=None, database_url=None):
def get_messages(session_id, dbsession=get_session(), database_url=None):
"""
Get all messages for a given session_id
"""
Expand Down
10 changes: 10 additions & 0 deletions dialog_lib/db/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import os

import sqlalchemy as sa
from sqlalchemy.orm import Session

def get_session():
engine = sa.create_engine(os.environ.get("DATABASE_URL"))
session = Session(engine)
yield session
session.close()
4 changes: 2 additions & 2 deletions dialog_lib/db/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import uuid

from .session import get_session
from .models import Chat


def create_chat_session(identifier=None, dbsession=None, model=Chat):
def create_chat_session(identifier=None, dbsession=get_session(), model=Chat):
if identifier is None:
identifier = uuid.uuid4().hex

Expand Down
4 changes: 3 additions & 1 deletion dialog_lib/loaders/csv.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dialog_lib.db import get_session
from dialog_lib.db.models import CompanyContent
from dialog_lib.embeddings.generate import generate_embedding

Expand All @@ -6,9 +7,10 @@


def load_csv(
file_path, dbsession, embeddings_model_instance=None,
file_path, dbsession=get_session(), embeddings_model_instance=None,
embedding_llm_model=None, embedding_llm_api_key=None, company_id=None
):

loader = CSVLoader(file_path=file_path)
contents = loader.load()

Expand Down
4 changes: 2 additions & 2 deletions dialog_lib/loaders/web.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from dialog_lib.db.models import CompanyContent
from dialog_lib.embeddings.generate import generate_embedding

from dialog_lib.db import get_session
from langchain_community.document_loaders import WebBaseLoader


def load_webpage(url, dbsession, embeddings_model_instance, company_id=None):
def load_webpage(url, embeddings_model_instance, dbsession=get_session(), company_id=None):
loader = WebBaseLoader(url)
contents = loader.load()

Expand Down
2 changes: 1 addition & 1 deletion dialog_lib/manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def anthropic(model, temperature, llm_api_key, prompt, debug):
@click.option("--llm-api-key", default=get_llm_key(), help="The LLM API key", required=True)
@click.option("--file", help="The CSV file to load the data from", required=True)
def load_csv(database_url, llm_api_key, file):
breakpoint()
engine = create_engine(database_url)
dbsession = Session(engine.connect())
csv_loader(
Expand All @@ -86,6 +85,7 @@ def load_csv(database_url, llm_api_key, file):
embedding_llm_api_key=llm_api_key
)
click.echo("## Loaded the CSV file to the database")
dbsession.close()


def main():
Expand Down
1 change: 0 additions & 1 deletion dialog_lib/tests/agents/test_abstract_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def test_abstract_agent_with_valid_config():
assert agent.dataset is None
assert agent.llm_api_key is None
assert agent.parent_session_id is None
assert agent.dbsession is None

def test_abstract_agent_get_prompt():
config = {
Expand Down
2 changes: 1 addition & 1 deletion dialog_lib/tests/loaders/test_web_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def test_load_web_content(mock_aioresponse, db_session, mocker):
mocker.patch('dialog_lib.loaders.web.generate_embedding', return_value=[0] * 1536)
mock_aioresponse.get('http://example.com', body='Hello, world!')

content = load_webpage('http://example.com', db_session, None, 1)
content = load_webpage('http://example.com', None, db_session, 1)
assert content.question == "Example Domain"
assert content.embedding == [0] * 1536

Expand Down

0 comments on commit 554bc9e

Please sign in to comment.