Skip to content

Commit

Permalink
Adds LCEL class and adds a sample class for OpenAI
Browse files Browse the repository at this point in the history
  • Loading branch information
vmesel committed Jun 23, 2024
1 parent f301478 commit bd9b54b
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 15 deletions.
108 changes: 98 additions & 10 deletions dialog_lib/agents/abstract.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
import warnings
from operator import itemgetter

from langchain.schema import format_document
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains.llm import LLMChain
from langchain.prompts.prompt import PromptTemplate
from langchain.memory.chat_memory import BaseChatMemory
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain_core.runnables import RunnablePassthrough
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain.chains.conversation.memory import ConversationBufferMemory

from dialog_lib.db.memory import CustomPostgresChatMessageHistory, get_memory_instance
from dialog_lib.embeddings.retrievers import DialogRetriever


class AbstractLLM:
Expand Down Expand Up @@ -104,45 +112,81 @@ def messages(self):
return self.memory.messages


class AbstractLcelClass(AbstractLLM):
class AbstractLCEL(AbstractLLM):

@property
def document_prompt(self):
return PromptTemplate.from_template(template="{page_content}")

@property
def chat_model(self):
def model(self):
"""
builds and returns the chat model for the LCEL
"""
raise NotImplementedError("Chat model must be implemented")

@property
def retriever(self):
"""
builds and returns the retriever for the LCEL
"""
raise NotImplementedError("Retriever must be implemented")

def documents_formatter(self, docs, document_separator="\n\n"):
"""
This is the default combine_documents function that returns the documents as is.
We use the default format_documents function from Langchain.
"""
doc_strings = [format_document(doc, self.document_prompt) for doc in docs]
return document_separator.join(doc_strings)

@property
def context_dict(self):
"""
builds and returns the context dictionary for the LCEL
"""
raise NotImplementedError("Context dictionary must be implemented")

context_dict = {
"input": RunnablePassthrough(),
"chat_history": itemgetter("chat_history"),
}

if self.retriever:
context_dict["context"] = itemgetter("input") | self.retriever | self.documents_formatter

return context_dict

@property
def chain(self):
"""
builds and returns the chain for the LCEL
"""
return (
self.context_dict | self.prompt | self.chat_model
self.context_dict | self.prompt | self.model
)

@property
def get_memory_instance(self):
def memory(self):
return get_memory_instance(
session_id=self.session_id,
sqlalchemy_session=self.dbsession,
database_url=self.config.get("database_url")
)

def get_session_history(self, something):
return CustomPostgresChatMessageHistory(
connection_string=self.config.get("database_url"),
session_id=self.session_id,
parent_session_id=self.parent_session_id,
table_name="chat_messages",
dbsession=self.dbsession,
)

@property
def runnable(self):
RunnableWithMessageHistory(
return RunnableWithMessageHistory(
self.chain,
self.get_memory_instance,
self.get_session_history,
input_messages_key='input',
history_messages_key="chat_history"
)
Expand Down Expand Up @@ -201,6 +245,13 @@ def process(self, input: str):

class AbstractDialog(AbstractLLM):
def __init__(self, *args, **kwargs):
warnings.filterwarnings("default", category=DeprecationWarning)
warnings.warn(
(
"AbstractDialog will be deprecated in release 0.2 due to the creation of Langchain's LCEL. ",
"Please use AbstractLCELDialog instead."
), DeprecationWarning, stacklevel=3
)
kwargs["config"] = kwargs.get("config", {})

self.memory_instance = kwargs.pop("memory", None)
Expand Down Expand Up @@ -244,4 +295,41 @@ def llm(self):
)
return LLMChain(
**chain_settings
)
)


class AbstractLCELDialog(AbstractLCEL):
def __init__(self, *args, **kwargs):
kwargs["config"] = kwargs.get("config", {})

self.memory_instance = kwargs.pop("memory", None)
self.llm_api_key = kwargs
self.prompt_content = kwargs.pop("prompt", None)
self.chat_model = kwargs.pop("model_class")
self.embedding_llm = kwargs.pop("embedding_llm")
super().__init__(*args, **kwargs)

@property
def retriever(self):
return DialogRetriever(
session=self.dbsession,
embedding_llm=self.embedding_llm,
)

@property
def model(self):
return self.chat_model

def generate_prompt(self, input_text):
self.prompt = ChatPromptTemplate.from_messages(
[
("system", "What can I help you with today?"),
MessagesPlaceholder(variable_name="chat_history"),
("system", "Here is some context for the user request: {context}"),
("human", "{input}"),
]
)
return input_text

def postprocess(self, output):
return output.content
18 changes: 16 additions & 2 deletions dialog_lib/agents/openai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from .abstract import AbstractDialog
from .abstract import AbstractDialog, AbstractLCELDialog
from langchain_openai import OpenAIEmbeddings
from langchain_openai.chat_models.base import ChatOpenAI
from dialog_lib.embeddings.retrievers import DialogRetriever


class DialogOpenAI(AbstractDialog):
Expand All @@ -14,4 +16,16 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def postprocess(self, output):
return output.get("text")
return output.get("text")


class DialogLCELOpenAI(AbstractLCELDialog):
def __init__(self, *args, **kwargs):
self.openai_api_key = kwargs.get("llm_api_key") or os.environ.get("OPENAI_API_KEY")
kwargs["model_class"] = ChatOpenAI(
model=kwargs.pop("model"),
temperature=kwargs.pop("temperature"),
openai_api_key=self.openai_api_key,
)
kwargs["embedding_llm"] = OpenAIEmbeddings(openai_api_key=self.openai_api_key)
super().__init__(*args, **kwargs)
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "dialog-lib"
version = "0.0.1.19"
version = "0.0.1.20"
description = ""
authors = ["Talkd.AI <[email protected]>"]
license = "MIT"
Expand All @@ -14,7 +14,7 @@ click = "^8.1.7"
pgvector = "^0.2.5"
langchain-openai = "^0.1.8"
psycopg2-binary = "^2.9.9"
langchain-postgres = "^0.0.7"
langchain-postgres = "0.0.7"
langchain-community = "^0.2.5"
langchain-anthropic = "^0.1.11"
bs4 = "^0.0.2"
Expand Down
32 changes: 32 additions & 0 deletions samples/openai/lcel/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import os
import logging
from uuid import uuid4
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from dialog_lib.agents.openai import DialogLCELOpenAI

logging.getLogger().setLevel(logging.ERROR)

database_url = "postgresql://talkdai:talkdai@db:5432/test_talkdai"

engine = create_engine(database_url)

dbsession = Session(engine)


agent = DialogLCELOpenAI(
model="gpt-4o",
temperature=0.1,
llm_api_key=os.environ.get("OPENAI_API_KEY"),
prompt="You are a bot called Sara. Be nice to other human beings.",
dbsession=dbsession,
config={
"database_url": database_url,
},
session_id=str(uuid4())
)

while True:
input_text = input("You: ")
output_text = agent.process(input_text)
print(f"Sara: {output_text}")

0 comments on commit bd9b54b

Please sign in to comment.