Skip to content

Commit

Permalink
Parameter Cache Rendering (#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
BenConstable9 authored Dec 17, 2024
1 parent cc3d8cc commit 1890062
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
"metadata": {},
"outputs": [],
"source": [
"agentic_text_2_sql = AutoGenText2Sql(engine_specific_rules=\"Use TOP X to limit the number of rows returned instead of LIMIT X. NEVER USE LIMIT X as it produces a syntax error.\", use_case=\"Analysing sales data across product categories.\").agentic_flow"
"agentic_text_2_sql = AutoGenText2Sql(engine_specific_rules=\"\", use_case=\"Analysing sales data across suppliers\")"
]
},
{
Expand All @@ -101,56 +101,8 @@
"metadata": {},
"outputs": [],
"source": [
"result = agentic_text_2_sql.run_stream(task=\"What country did we sell the most to in June 2008?\")\n",
"await Console(result)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"await agentic_text_2_sql.reset()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"result = agentic_text_2_sql.run_stream(task=\"What are the total number of sales within 2008 for the mountain bike category?\")\n",
"await Console(result)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"await agentic_text_2_sql.reset()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"result = agentic_text_2_sql.run_stream(task=\"What are the total number of sales within 2008?\")\n",
"await Console(result)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"await agentic_text_2_sql.reset()"
"result = await agentic_text_2_sql.process_question(task=\"What total number of orders in June 2008?\")\n",
"await Console(result)\n"
]
},
{
Expand Down
61 changes: 36 additions & 25 deletions text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,57 +68,52 @@ def get_all_agents(self):
# Get current datetime for the Query Rewrite Agent
current_datetime = datetime.now()

QUERY_REWRITE_AGENT = LLMAgentCreator.create(
self.query_rewrite_agent = LLMAgentCreator.create(
"query_rewrite_agent", current_datetime=current_datetime
)

SQL_QUERY_GENERATION_AGENT = LLMAgentCreator.create(
self.sql_query_generation_agent = LLMAgentCreator.create(
"sql_query_generation_agent",
target_engine=self.target_engine,
engine_specific_rules=self.engine_specific_rules,
**self.kwargs,
)

SQL_SCHEMA_SELECTION_AGENT = SqlSchemaSelectionAgent(
self.sql_schema_selection_agent = SqlSchemaSelectionAgent(
target_engine=self.target_engine,
engine_specific_rules=self.engine_specific_rules,
**self.kwargs,
)

SQL_QUERY_CORRECTION_AGENT = LLMAgentCreator.create(
self.sql_query_correction_agent = LLMAgentCreator.create(
"sql_query_correction_agent",
target_engine=self.target_engine,
engine_specific_rules=self.engine_specific_rules,
**self.kwargs,
)

SQL_DISAMBIGUATION_AGENT = LLMAgentCreator.create(
self.sql_disambiguation_agent = LLMAgentCreator.create(
"sql_disambiguation_agent",
target_engine=self.target_engine,
engine_specific_rules=self.engine_specific_rules,
**self.kwargs,
)

QUESTION_DECOMPOSITION_AGENT = LLMAgentCreator.create(
"question_decomposition_agent"
)

# Auto-responding UserProxyAgent
USER_PROXY = EmptyResponseUserProxyAgent(name="user_proxy")
self.user_proxy = EmptyResponseUserProxyAgent(name="user_proxy")

agents = [
USER_PROXY,
QUERY_REWRITE_AGENT,
SQL_QUERY_GENERATION_AGENT,
SQL_SCHEMA_SELECTION_AGENT,
SQL_QUERY_CORRECTION_AGENT,
QUESTION_DECOMPOSITION_AGENT,
SQL_DISAMBIGUATION_AGENT,
self.user_proxy,
self.query_rewrite_agent,
self.sql_query_generation_agent,
self.sql_schema_selection_agent,
self.sql_query_correction_agent,
self.sql_disambiguation_agent,
]

if self.use_query_cache:
SQL_QUERY_CACHE_AGENT = SqlQueryCacheAgent()
agents.append(SQL_QUERY_CACHE_AGENT)
self.query_cache_agent = SqlQueryCacheAgent()
agents.append(self.query_cache_agent)

return agents

Expand All @@ -136,7 +131,7 @@ def termination_condition(self):
)
return termination

def unified_selector(messages):
def unified_selector(self, messages):
"""Unified selector for the complete flow."""
logging.info("Messages: %s", messages)
decision = None
Expand Down Expand Up @@ -195,18 +190,34 @@ def agentic_flow(self):
)
return flow

async def process_question(self, task: str, chat_history: list[str] = None):
"""Process the complete question through the unified system."""
async def process_question(
self, task: str, chat_history: list[str] = None, parameters: dict = None
):
"""Process the complete question through the unified system.
Args:
----
task (str): The user question to process.
chat_history (list[str], optional): The chat history. Defaults to None.
parameters (dict, optional): The parameters to pass to the agents. Defaults to None.
Returns:
-------
dict: The response from the system.
"""

logging.info("Processing question: %s", task)
logging.info("Chat history: %s", chat_history)

agent_input = {"user_question": task, "chat_history": {}}
agent_input = {
"user_question": task,
"chat_history": {},
"parameters": parameters,
}

if chat_history is not None:
# Update input
for idx, chat in enumerate(chat_history):
agent_input[f"chat_{idx}"] = chat

result = await self.agentic_flow.run_stream(task=json.dumps(agent_input))
return result
return self.agentic_flow.run_stream(task=json.dumps(agent_input))
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@


class SqlQueryCacheAgent(BaseChatAgent):
def __init__(self, name: str = "sql_query_cache_agent"):
def __init__(self):
super().__init__(
name,
"sql_query_cache_agent",
"An agent that fetches the queries from the cache based on the user question.",
)

Expand All @@ -39,10 +39,13 @@ async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[AgentMessage | Response, None]:
# Get the decomposed questions from the query_rewrite_agent
parameter_input = messages[0].content
last_response = messages[-1].content
try:
user_questions = json.loads(last_response)
user_parameters = json.loads(parameter_input)["parameters"]
logging.info(f"Processing questions: {user_questions}")
logging.info(f"Input Parameters: {user_parameters}")

# Initialize results dictionary
cached_results = {
Expand All @@ -55,7 +58,7 @@ async def on_messages_stream(
# Fetch the queries from the cache based on the question
logging.info(f"Fetching queries from cache for question: {question}")
cached_query = await self.sql_connector.fetch_queries_from_cache(
question
question, parameters=user_parameters
)

# If any question has pre-run results, set the flag
Expand Down
42 changes: 41 additions & 1 deletion text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import sqlglot
from abc import ABC, abstractmethod
from datetime import datetime
from jinja2 import Template


class SqlConnector(ABC):
Expand All @@ -30,6 +31,18 @@ def get_current_datetime(self) -> str:
"""Get the current datetime."""
return datetime.now().strftime("%d/%m/%Y, %H:%M:%S")

def get_current_date(self) -> str:
    """Return the current date as a DD/MM/YYYY string.

    Used to populate the ``date`` parameter injected into cached SQL
    query templates when the caller does not supply one.
    """
    # Format spec on the f-string delegates to datetime.__format__,
    # which applies the same strftime codes as the original call.
    return f"{datetime.now():%d/%m/%Y}"

def get_current_time(self) -> str:
    """Return the current wall-clock time as an HH:MM:SS (24-hour) string.

    Used to populate the ``time`` parameter injected into cached SQL
    query templates when the caller does not supply one.
    """
    # Format spec on the f-string delegates to datetime.__format__,
    # which applies the same strftime codes as the original call.
    return f"{datetime.now():%H:%M:%S}"

def get_current_unix_timestamp(self) -> int:
    """Return the current time as whole seconds since the Unix epoch.

    Used to populate the ``unix_timestamp`` parameter injected into
    cached SQL query templates when the caller does not supply one.
    """
    # Truncate (not round) the fractional seconds, matching int() on
    # the float timestamp in the original implementation.
    now = datetime.now()
    return int(now.timestamp())

@abstractmethod
async def query_execution(
self,
Expand Down Expand Up @@ -118,7 +131,9 @@ async def query_validation(
logging.info("SQL Query is valid.")
return True

async def fetch_queries_from_cache(self, question: str) -> str:
async def fetch_queries_from_cache(
self, question: str, parameters: dict = None
) -> str:
"""Fetch the queries from the cache based on the question.
Args:
Expand All @@ -129,6 +144,23 @@ async def fetch_queries_from_cache(self, question: str) -> str:
-------
str: The formatted string of the queries fetched from the cache. This is injected into the prompt.
"""

if parameters is None:
parameters = {}

# Populate the parameters
if "date" not in parameters:
parameters["date"] = self.get_current_date()

if "time" not in parameters:
parameters["time"] = self.get_current_time()

if "datetime" not in parameters:
parameters["datetime"] = self.get_current_datetime()

if "unix_timestamp" not in parameters:
parameters["unix_timestamp"] = self.get_current_unix_timestamp()

cached_schemas = await self.ai_search_connector.run_ai_search_query(
question,
["QuestionEmbedding"],
Expand All @@ -146,6 +178,14 @@ async def fetch_queries_from_cache(self, question: str) -> str:
"cached_questions_and_schemas": None,
}

# loop through all sql queries and populate the template in place
for schema in cached_schemas:
sql_queries = schema["SqlQueryDecomposition"]
for sql_query in sql_queries:
sql_query["SqlQuery"] = Template(sql_query["SqlQuery"]).render(
**parameters
)

logging.info("Cached schemas: %s", cached_schemas)
if self.pre_run_query_cache and len(cached_schemas) > 0:
# check the score
Expand Down

0 comments on commit 1890062

Please sign in to comment.