Vaana AI Integration
sarath-nalluri committed Nov 14, 2024
1 parent 9ba2dd5 commit d286320
Showing 11 changed files with 449 additions and 237 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -81,7 +81,7 @@ Unstructured Data: Ingests product manuals, product catalogs, and FAQ
This reference solution implements [different sub-agents using the open-source LangGraph framework and a supervisor agent to orchestrate the entire flow.](./src/agent/) These sub-agents address common customer service tasks for the included sample dataset. They rely on the Llama 3.1 models and NVIDIA NIM microservices for generating responses, converting natural language into SQL queries, and assessing the sentiment of the conversation.

## Key Components
* [**Structured Data Retriever**](./src/retrievers/structured_data/): Works in tandem with a Postgres database and PandasAI to fetch relevant data based on user queries.
* [**Structured Data Retriever**](./src/retrievers/structured_data/): Works in tandem with a Postgres database and Vanna.AI to fetch relevant data based on user queries.
* [**Unstructured Data Retriever**](./src/retrievers/unstructured_data/): Processes unstructured data (e.g., PDFs, FAQs) by chunking it, creating embeddings using the NeMo Retriever embedding NIM, and storing it in Milvus for fast retrieval.
* [**Analytics and Admin Operations**](./src/analytics/): To support operational requirements, the blueprint includes reference code and APIs for managing key administrative tasks
* Storing conversation histories
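For orientation, the sketch below walks the new flow end to end: connect to Postgres, train Vanna on the schema, then ask a question in natural language. It assumes the compose-file defaults that appear later in this diff; the sample question is made up.

```python
# A minimal sketch of the Vanna.AI-backed structured retrieval flow added in
# this commit. Connection values are the compose-file defaults, not requirements.
from src.retrievers.structured_data.vaanaai.vaana_base import VannaWrapper

vaana_client = VannaWrapper()
vaana_client.connect_to_postgres(
    host="postgres",
    dbname="customer_data",
    user="postgres",
    password="password",
    port=5432,
)

# Train once on the database schema so questions can be translated to SQL.
vaana_client.do_training()

# Returns a pandas DataFrame on success (or an error string for invalid SQL).
result_df = vaana_client.ask_query(
    question="What did user 123 purchase last month?",
    user_id="123",
)
print(result_df)
```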
20 changes: 17 additions & 3 deletions deploy/compose/docker-compose.yaml
@@ -194,11 +194,15 @@ services:
APP_LLM_MODELNAME: ${APP_LLM_MODELNAME:-meta/llama-3.1-70b-instruct}
APP_LLM_MODELENGINE: nvidia-ai-endpoints
APP_LLM_SERVERURL: ${APP_LLM_SERVERURL:-""}
APP_LLM_MODELNAMEPANDASAI: ${APP_LLM_MODELNAME:-meta/llama-3.1-70b-instruct}
APP_EMBEDDINGS_MODELNAME: ${APP_EMBEDDINGS_MODELNAME:-nvidia/nv-embedqa-e5-v5}
APP_EMBEDDINGS_MODELENGINE: ${APP_EMBEDDINGS_MODELENGINE:-nvidia-ai-endpoints}
APP_EMBEDDINGS_SERVERURL: ${APP_EMBEDDINGS_SERVERURL:-""}
APP_PROMPTS_CHATTEMPLATE: "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Please ensure that your responses are positive in nature."
APP_PROMPTS_RAGTEMPLATE: "You are a helpful AI assistant named Envie. You will reply to questions only based on the context that you are provided. If something is out of context, you will refrain from replying and politely decline to respond to the user."
NVIDIA_API_KEY: ${NVIDIA_API_KEY}
COLLECTION_NAME: ${COLLECTION_NAME:-structured_data}
APP_VECTORSTORE_URL: "http://milvus:19530"
APP_VECTORSTORE_NAME: "milvus"
# Database name used to store user purchase history; only Postgres is supported
APP_DATABASE_NAME: ${APP_DATABASE_NAME:-"postgres"}
APP_DATABASE_URL: ${APP_DATABASE_URL:-"postgres:5432"}
@@ -216,6 +220,8 @@ services:
postgres:
condition: service_healthy
required: false
milvus:
condition: service_healthy


# =======================
@@ -333,9 +339,17 @@ services:
ports:
- "19530:19530"
- "9091:9091"
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"]
interval: 30s
start_period: 90s
timeout: 20s
retries: 3
depends_on:
- "etcd"
- "minio"
etcd:
condition: service_healthy
minio:
condition: service_healthy
deploy:
resources:
reservations:
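The healthcheck added above is what lets dependent services wait on `condition: service_healthy`. Outside Compose, the same readiness probe can be reproduced with only the standard library; a sketch assuming the default `9091:9091` port mapping:

```python
# Poll Milvus's /healthz endpoint (the same check the Compose healthcheck
# runs via curl) until it answers 200 OK or a deadline passes.
import time
import urllib.error
import urllib.request

def wait_for_milvus(url: str = "http://localhost:9091/healthz",
                    timeout_s: float = 90.0,
                    interval_s: float = 5.0) -> bool:
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(url, timeout=20) as resp:
                if resp.status == 200:
                    return True
        except (urllib.error.URLError, OSError):
            pass  # Milvus not ready yet; retry after a short sleep.
        time.sleep(interval_s)
    return False

if __name__ == "__main__":
    print("Milvus ready:", wait_for_milvus())
```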
10 changes: 5 additions & 5 deletions src/common/utils.py
@@ -195,15 +195,15 @@ def get_llm(**kwargs) -> LLM | SimpleChatModel:
@lru_cache
def get_embedding_model() -> Embeddings:
"""Create the embedding model."""
model_kwargs = {"device": "cpu"}
if torch.cuda.is_available():
model_kwargs["device"] = "cuda:0"

encode_kwargs = {"normalize_embeddings": False}
settings = get_config()

logger.info(f"Using {settings.embeddings.model_engine} as model engine and {settings.embeddings.model_name} and model for embeddings")
if settings.embeddings.model_engine == "huggingface":
model_kwargs = {"device": "cpu"}
if torch.cuda.is_available():
model_kwargs["device"] = "cuda:0"

encode_kwargs = {"normalize_embeddings": False}
hf_embeddings = HuggingFaceEmbeddings(
model_name=settings.embeddings.model_name,
model_kwargs=model_kwargs,
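This change scopes the CUDA device probe to the `huggingface` branch, the only engine that uses it; behavior for NVIDIA AI Endpoints is unchanged. A minimal usage sketch, assuming credentials for the configured engine are set in the environment:

```python
# Usage is unchanged by this refactor: the factory reads the engine and model
# from get_config() and returns a LangChain Embeddings instance.
# Assumes NVIDIA_API_KEY (or a local HuggingFace model) is configured.
from src.common.utils import get_embedding_model

embedder = get_embedding_model()
vectors = embedder.embed_documents(["What is the return policy?"])
print(len(vectors), len(vectors[0]))  # number of texts, embedding dimension
```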
85 changes: 28 additions & 57 deletions src/retrievers/structured_data/chains.py
@@ -15,36 +15,41 @@

""" Retriever pipeline for extracting data from structured information"""
import logging
import os
from typing import Any, Dict, List

from pandasai import Agent as PandasAI_Agent
from pandasai.responses.response_parser import ResponseParser
from langchain.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from src.retrievers.structured_data.connector import get_postgres_connector
from urllib.parse import urlparse
import pandas as pd
from src.retrievers.structured_data.vaanaai.vaana_base import VannaWrapper
from src.retrievers.base import BaseExample
from src.common.utils import get_config, get_prompts
from src.retrievers.structured_data.pandasai.llms.nv_aiplay import NVIDIA as PandasAI_NVIDIA
from src.common.utils import get_config

logger = logging.getLogger(__name__)
settings = get_config()

# Load the vaana_client
vaana_client = VannaWrapper()
# Connect to the Postgres DB
app_database_url = get_config().database.url

class PandasDataFrame(ResponseParser):
"""Returns Pandas Dataframe instead of SmartDataFrame"""

def __init__(self, context) -> None:
super().__init__(context)
# Parse the URL
parsed_url = urlparse(f"//{app_database_url}", scheme='postgres')

def format_dataframe(self, result):
return result["value"]
# Extract host and port
host = parsed_url.hostname
port = parsed_url.port

vaana_client.connect_to_postgres(
host=parsed_url.hostname,
dbname=os.getenv("POSTGRES_DB", "customer_data"),
user=os.getenv("POSTGRES_USER", "postgres"),
password=os.getenv("POSTGRES_PASSWORD", "password"),
port=parsed_url.port
)
# Do Training
vaana_client.do_training()
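
`APP_DATABASE_URL` defaults to a bare `host:port` pair (`postgres:5432` in the compose file), so the `//` prefix above is what makes `urlparse` treat it as a network location instead of a `scheme:path` string. A quick worked example:

```python
# Worked example of the parsing trick above (Python 3.6+ behavior).
from urllib.parse import urlparse

bare = "postgres:5432"  # default APP_DATABASE_URL from docker-compose.yaml

# Without the "//" prefix, host:port is misread as scheme:path.
wrong = urlparse(bare)
print(wrong.scheme, wrong.hostname, wrong.port)  # postgres None None

# With the "//" prefix, netloc parsing applies and host/port come out right.
right = urlparse(f"//{bare}", scheme="postgres")
print(right.scheme, right.hostname, right.port)  # postgres postgres 5432
```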

class CSVChatbot(BaseExample):
"""RAG example showcasing CSV parsing using Pandas AI Agent"""
"""RAG example showcasing CSV parsing using Vaana AI Agent"""

def ingest_docs(self, filepath: str, filename: str):
"""Ingest documents to the VectorDB."""
@@ -55,61 +60,27 @@ def document_search(self, content: str, num_docs: int, user_id: str = None, conv
"""Execute a Document Search."""

logger.info("Using document_search to fetch response from database as text")
postgres_connector = None # Initialize connector

try:
logger.info("Using document_search to fetch response from database as text")
if user_id:
postgres_connector = get_postgres_connector(user_id)
pass
else:
logger.warning("Enter a proper User ID")
return [{"content": "No response generated, make to give a proper User ID."}]

# TODO: Pass conv history to the LLM
llm_data_retrieval = PandasAI_NVIDIA(temperature=0.2, model=settings.llm.model_name_pandas_ai)

config_data_retrieval = {"llm": llm_data_retrieval, "response_parser": PandasDataFrame, "max_retries": 1, "enable_cache": False}
agent_data_retrieval = PandasAI_Agent([postgres_connector], config=config_data_retrieval, memory_size=20)

prompt_config = get_prompts().get("prompts")

data_retrieval_prompt = ChatPromptTemplate(
messages=[
SystemMessagePromptTemplate.from_template(prompt_config.get("data_retrieval_template", [])),
HumanMessagePromptTemplate.from_template("{query}"),
],
input_variables=["description", "instructions", "query"],
)


chat_prompt = data_retrieval_prompt.format_prompt(
description=prompt_config.get("dataframe_prompts").get("customer_data").get("description"),
instructions=prompt_config.get("dataframe_prompts").get("customer_data").get("instructions"),
query=content,
).to_string()
result_df = vaana_client.ask_query(question=content, user_id=user_id)

result_df = agent_data_retrieval.chat(
chat_prompt
)
logger.info("Result Data Frame: %s", result_df)
if not result_df:
if (isinstance(result_df, pd.DataFrame) and result_df.empty) or (isinstance(result_df, str) and result_df == "not valid sql") or (result_df is None):
logger.warning("Retrieval failed to get any relevant context")
return [{"content": "No response generated from LLM, make sure your query is relavent to the ingested document."}]
raise Exception("No response generated from LLM. Make sure your query is relevant to the ingested document.")

result_df = str(result_df)
return [{"content": result_df}]
except Exception as e:
logger.error("An error occurred during document search: %s", str(e))
raise # Re-raise the exception after logging

finally:
if postgres_connector:
postgres_connector._connection._dbapi_connection.close()
postgres_connector._connection.close()
postgres_connector._engine.dispose()
import gc
gc.collect()
logger.info("Postgres connector deleted.")

def get_documents(self) -> List[str]:
"""Retrieves filenames stored in the vector store."""
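For orientation, a hypothetical call into the reworked retriever (the question and user ID are illustrative). Note that importing the module runs the Postgres connection and Vanna training shown above:

```python
# Hypothetical usage of the reworked retriever. document_search returns a
# one-element list whose "content" field holds the stringified DataFrame,
# and raises if retrieval produced nothing usable.
from src.retrievers.structured_data.chains import CSVChatbot

chatbot = CSVChatbot()
try:
    docs = chatbot.document_search(
        content="How many orders did user 42 place?",
        num_docs=1,
        user_id="42",
    )
    print(docs[0]["content"])
except Exception as err:
    print(f"Retrieval failed: {err}")
```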
43 changes: 0 additions & 43 deletions src/retrievers/structured_data/connector.py

This file was deleted.

14 changes: 0 additions & 14 deletions src/retrievers/structured_data/pandasai/llms/__init__.py

This file was deleted.

111 changes: 0 additions & 111 deletions src/retrievers/structured_data/pandasai/llms/nv_aiplay.py

This file was deleted.

5 changes: 2 additions & 3 deletions src/retrievers/structured_data/requirements.txt
@@ -5,7 +5,6 @@ python-multipart==0.0.9
langchain==0.2.16
langchain-nvidia-ai-endpoints==0.2.2
dataclass-wizard==0.22.3
pandas==1.5.3
pandasai==2.2.14
numexpr==2.9.0
psycopg2-binary==2.9.9
psycopg2-binary==2.9.9
vanna[postgres,milvus]==0.7.5
