Skip to content

Commit

Permalink
Update prompts and agents to support programmatic sources (#108)
Browse files Browse the repository at this point in the history
  • Loading branch information
BenConstable9 authored Dec 17, 2024
1 parent 0e11c29 commit 46b4cf8
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 38 deletions.
31 changes: 19 additions & 12 deletions text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from autogen_agentchat.conditions import (
TextMentionTermination,
MaxMessageTermination,
SourceMatchTermination,
)
from autogen_agentchat.teams import SelectorGroupChat
from autogen_text_2_sql.creators.llm_model_creator import LLMModelCreator
Expand All @@ -13,6 +12,9 @@
from autogen_text_2_sql.custom_agents.sql_schema_selection_agent import (
SqlSchemaSelectionAgent,
)
from autogen_text_2_sql.custom_agents.answer_and_sources_agent import (
AnswerAndSourcesAgent,
)
from autogen_agentchat.agents import UserProxyAgent
from autogen_agentchat.messages import TextMessage
from autogen_agentchat.base import Response
Expand Down Expand Up @@ -99,6 +101,8 @@ def get_all_agents(self):
**self.kwargs,
)

self.answer_and_sources_agent = AnswerAndSourcesAgent()

# Auto-responding UserProxyAgent
self.user_proxy = EmptyResponseUserProxyAgent(name="user_proxy")

Expand All @@ -109,6 +113,7 @@ def get_all_agents(self):
self.sql_schema_selection_agent,
self.sql_query_correction_agent,
self.sql_disambiguation_agent,
self.answer_and_sources_agent,
]

if self.use_query_cache:
Expand All @@ -122,11 +127,7 @@ def termination_condition(self):
"""Define the termination condition for the chat."""
termination = (
TextMentionTermination("TERMINATE")
| (
TextMentionTermination("answer")
& TextMentionTermination("sources")
& SourceMatchTermination("sql_query_correction_agent")
)
| (TextMentionTermination("answer") & TextMentionTermination("sources"))
| MaxMessageTermination(20)
)
return termination
Expand Down Expand Up @@ -166,14 +167,20 @@ def unified_selector(self, messages):
decision = "sql_query_generation_agent"

elif messages[-1].source == "sql_query_correction_agent":
decision = "sql_query_generation_agent"
if "answer" in messages[-1].content is not None:
decision = "answer_and_sources_agent"
else:
decision = "sql_query_generation_agent"

elif messages[-1].source == "sql_query_generation_agent":
decision = "sql_query_correction_agent"
elif messages[-1].source == "sql_query_correction_agent":
decision = "sql_query_correction_agent"
elif messages[-1].source == "answer_agent":
return "user_proxy" # Let user_proxy send TERMINATE
if "query_execution_with_limit" in messages[-1].content:
decision = "sql_query_correction_agent"
else:
# Rerun
decision = "sql_query_generation_agent"

elif messages[-1].source == "answer_and_sources_agent":
decision = "user_proxy" # Let user_proxy send TERMINATE

logging.info("Decision: %s", decision)
return decision
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from autogen_core.components.tools import FunctionTool
from autogen_core.components.tools import FunctionToolAlias
from autogen_agentchat.agents import AssistantAgent
from text_2_sql_core.connectors.factory import ConnectorFactory
from text_2_sql_core.prompts.load import load
Expand Down Expand Up @@ -32,25 +32,25 @@ def get_tool(cls, sql_helper, ai_search_helper, tool_name: str):
tool_name (str): The name of the tool to retrieve.
Returns:
FunctionTool: The tool."""
FunctionToolAlias: The tool."""

if tool_name == "sql_query_execution_tool":
return FunctionTool(
return FunctionToolAlias(
sql_helper.query_execution_with_limit,
description="Runs an SQL query against the SQL Database to extract information",
)
elif tool_name == "sql_get_entity_schemas_tool":
return FunctionTool(
return FunctionToolAlias(
sql_helper.get_entity_schemas,
description="Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Extract key terms from the user question and use these as the search term. Several entities may be returned. Only use when the provided schemas in the system prompt are not sufficient to answer the question.",
)
elif tool_name == "sql_get_column_values_tool":
return FunctionTool(
return FunctionToolAlias(
ai_search_helper.get_column_values,
description="Gets the values of a column in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned. Use this to get the correct value to apply against a filter for a user's question.",
)
elif tool_name == "current_datetime_tool":
return FunctionTool(
return FunctionToolAlias(
sql_helper.get_current_datetime,
description="Gets the current date and time.",
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import AsyncGenerator, List, Sequence

from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import AgentMessage, ChatMessage, TextMessage
from autogen_core import CancellationToken
import json
from json import JSONDecodeError
import logging
import pandas as pd


class AnswerAndSourcesAgent(BaseChatAgent):
    """Agent that attaches programmatically gathered sources to the final answer.

    The last message it receives is parsed as the JSON answer object. Every
    message it receives is then scanned for JSON payloads tagged
    ``"type": "query_execution_with_limit"``; each one is recorded as an entry
    in the answer's ``sources`` list (query text, raw rows and a markdown table).
    """

    def __init__(self):
        super().__init__(
            "answer_and_sources_agent",
            "An agent that formats the final answer and sources.",
        )

    @property
    def produced_message_types(self) -> List[type[ChatMessage]]:
        """Return the message types this agent emits (only ``TextMessage``)."""
        return [TextMessage]

    async def on_messages(
        self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
    ) -> Response:
        """Drain ``on_messages_stream`` and return the last ``Response`` it yields."""
        response: Response | None = None
        async for message in self.on_messages_stream(messages, cancellation_token):
            if isinstance(message, Response):
                response = message
        # Internal invariant: the stream always ends with a Response.
        assert response is not None
        return response

    async def on_messages_stream(
        self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
    ) -> AsyncGenerator[AgentMessage | Response, None]:
        """Build the final answer object and yield it as a single ``Response``.

        Args:
            messages: Messages handed to this agent; the last one must contain
                a JSON object (the answer) as its content.
            cancellation_token: Unused here.

        Yields:
            One ``Response`` whose ``TextMessage`` content is the JSON-encoded
            answer object with a freshly populated ``sources`` list.

        Raises:
            json.JSONDecodeError: If the last message is not valid JSON.
            TypeError: If the last message is JSON but not a JSON object.
            Exception: Any error raised while turning a query-execution payload
                into a source entry is logged and re-raised.
        """
        # Load the json of the last message to populate the final output object
        final_output_object = json.loads(messages[-1].content)
        if not isinstance(final_output_object, dict):
            # Fail with an explicit message instead of an opaque item-assignment error.
            raise TypeError(
                "The last message must be a JSON object, got "
                f"{type(final_output_object).__name__}"
            )
        final_output_object["sources"] = []

        for message in messages:
            content = message.content

            # Non-text content (e.g. multimodal lists) can never be a query
            # execution payload; skip it rather than letting json.loads raise.
            if not isinstance(content, (str, bytes, bytearray)):
                logging.info("Skipping non-text message: %s", message)
                continue

            try:
                payload = json.loads(content)
                logging.info("Loaded: %s", payload)

                # Only JSON objects can carry a "type" tag; scalars, null and
                # lists are valid JSON but are not query execution results.
                if (
                    isinstance(payload, dict)
                    and payload.get("type") == "query_execution_with_limit"
                ):
                    dataframe = pd.DataFrame(payload["sql_rows"])
                    final_output_object["sources"].append(
                        {
                            "sql_query": payload["sql_query"].replace("\n", " "),
                            "sql_rows": payload["sql_rows"],
                            "markdown_table": dataframe.to_markdown(index=False),
                        }
                    )

            except JSONDecodeError:
                # Plain-text messages are expected in the history; ignore them.
                logging.info("Could not load message: %s", message)
                continue

            except Exception as e:
                logging.error("Error processing message: %s", e)
                raise

        yield Response(
            chat_message=TextMessage(
                content=json.dumps(final_output_object), source=self.name
            )
        )

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        """No state is kept between runs, so there is nothing to reset."""
        pass
2 changes: 2 additions & 0 deletions text_2_sql/text_2_sql_core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@ dependencies = [
"networkx>=3.4.2",
"numpy<2.0.0",
"openai>=1.55.3",
"pandas>=2.2.3",
"pydantic>=2.10.2",
"python-dotenv>=1.0.1",
"pyyaml>=6.0.2",
"rich>=13.9.4",
"sqlglot[rs]>=25.32.1",
"tabulate>=0.9.0",
"tenacity>=9.0.0",
"typer>=0.14.0",
]
Expand Down
19 changes: 17 additions & 2 deletions text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from abc import ABC, abstractmethod
from datetime import datetime
from jinja2 import Template
import json


class SqlConnector(ABC):
Expand Down Expand Up @@ -109,9 +110,23 @@ async def query_execution_with_limit(
validation_result = await self.query_validation(sql_query)

if isinstance(validation_result, bool) and validation_result:
return await self.query_execution(sql_query, cast_to=None, limit=25)
result = await self.query_execution(sql_query, cast_to=None, limit=25)

return json.dumps(
{
"type": "query_execution_with_limit",
"sql_query": sql_query,
"sql_rows": result,
}
)
else:
return validation_result
return json.dumps(
{
"type": "errored_query_execution_with_limit",
"sql_query": sql_query,
"errors": validation_result,
}
)

async def query_validation(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,37 +20,25 @@ system_message:
<output_format>
- **If the SQL query is valid and the results are correct**:
```json
{
\"answer\": \"<GENERATED ANSWER>\",
\"sources\": [
{
\"sql_result_snippet\": \"<SQL QUERY RESULT 1>\",
\"sql_query_used\": \"<SOURCE 1 SQL QUERY>\",
\"explanation\": \"<EXPLANATION OF SQL QUERY 1>\"
},
{
\"sql_result_snippet\": \"<SQL QUERY RESULT 2>\",
\"sql_query_used\": \"<SOURCE 2 SQL QUERY>\",
\"explanation\": \"<EXPLANATION OF SQL QUERY 2>\"
}
]
}
```
- **If the SQL query needs corrections**:
```json
[
{
\"requested_fix\": \"<EXPLANATION OF REQUESTED FIX OF THE SQL QUERY>\"
}
]
```
- **If the SQL query cannot be corrected**:
```json
{
\"error\": \"Unable to correct the SQL query. Please request a new SQL query.\"
}
```
Followed by **TERMINATE**.
</output_format>
"

0 comments on commit 46b4cf8

Please sign in to comment.