Skip to content

Commit

Permalink
Fix bad changes (#101)
Browse files Browse the repository at this point in the history
  • Loading branch information
BenConstable9 authored Dec 13, 2024
1 parent 1e2d21f commit 210ee09
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 32 deletions.
26 changes: 12 additions & 14 deletions text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
from autogen_agentchat.base import Response
import json
import os
import asyncio
from datetime import datetime


class EmptyResponseUserProxyAgent(UserProxyAgent):
"""UserProxyAgent that automatically responds with empty messages."""

def __init__(self, name):
super().__init__(name=name)
self._has_responded = False
Expand All @@ -35,6 +36,7 @@ async def on_messages_stream(self, messages, sender=None, config=None):
yield message
yield Response(chat_message=message)


class AutoGenText2Sql:
def __init__(self, engine_specific_rules: str, **kwargs: dict):
self.use_query_cache = False
Expand Down Expand Up @@ -65,32 +67,31 @@ def get_all_agents(self):
"""Get all agents for the complete flow."""
# Get current datetime for the Query Rewrite Agent
current_datetime = datetime.now()

QUERY_REWRITE_AGENT = LLMAgentCreator.create(
"query_rewrite_agent",
current_datetime=current_datetime
"query_rewrite_agent", current_datetime=current_datetime
)

SQL_QUERY_GENERATION_AGENT = LLMAgentCreator.create(
"sql_query_generation_agent",
target_engine=self.target_engine,
engine_specific_rules=self.engine_specific_rules,
**self.kwargs,
)

SQL_SCHEMA_SELECTION_AGENT = SqlSchemaSelectionAgent(
target_engine=self.target_engine,
engine_specific_rules=self.engine_specific_rules,
**self.kwargs,
)

SQL_QUERY_CORRECTION_AGENT = LLMAgentCreator.create(
"sql_query_correction_agent",
target_engine=self.target_engine,
engine_specific_rules=self.engine_specific_rules,
**self.kwargs,
)

SQL_DISAMBIGUATION_AGENT = LLMAgentCreator.create(
"sql_disambiguation_agent",
target_engine=self.target_engine,
Expand All @@ -101,11 +102,9 @@ def get_all_agents(self):
QUESTION_DECOMPOSITION_AGENT = LLMAgentCreator.create(
"question_decomposition_agent"
)

# Auto-responding UserProxyAgent
USER_PROXY = EmptyResponseUserProxyAgent(
name="user_proxy"
)
USER_PROXY = EmptyResponseUserProxyAgent(name="user_proxy")

agents = [
USER_PROXY,
Expand All @@ -114,7 +113,7 @@ def get_all_agents(self):
SQL_SCHEMA_SELECTION_AGENT,
SQL_QUERY_CORRECTION_AGENT,
QUESTION_DECOMPOSITION_AGENT,
SQL_DISAMBIGUATION_AGENT
SQL_DISAMBIGUATION_AGENT,
]

if self.use_query_cache:
Expand Down Expand Up @@ -192,7 +191,6 @@ def agentic_flow(self):
allow_repeated_speaker=False,
model_client=LLMModelCreator.get_model("4o-mini"),
termination_condition=self.termination_condition,
selector_func=self.selector,
selector_func=self.unified_selector,
)
return flow
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,21 @@ async def on_messages_stream(
# Initialize results dictionary
cached_results = {
"cached_questions_and_schemas": [],
"contains_pre_run_results": False
"contains_pre_run_results": False,
}

# Process each question sequentially
for question in user_questions:
# Fetch the queries from the cache based on the question
logging.info(f"Fetching queries from cache for question: {question}")
cached_query = await self.sql_connector.fetch_queries_from_cache(question)

cached_query = await self.sql_connector.fetch_queries_from_cache(
question
)

# If any question has pre-run results, set the flag
if cached_query.get("contains_pre_run_results", False):
cached_results["contains_pre_run_results"] = True

# Add the cached results for this question
if cached_query.get("cached_questions_and_schemas"):
cached_results["cached_questions_and_schemas"].extend(
Expand All @@ -75,7 +77,9 @@ async def on_messages_stream(
except json.JSONDecodeError:
# If not JSON array, process as single question
logging.info(f"Processing single question: {last_response}")
cached_queries = await self.sql_connector.fetch_queries_from_cache(last_response)
cached_queries = await self.sql_connector.fetch_queries_from_cache(
last_response
)
yield Response(
chat_message=TextMessage(
content=json.dumps(cached_queries), source=self.name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,21 +198,19 @@ async def get_entity_schemas(
logging.info("Search Text: %s", text)

retrieval_fields = [
# "FQN",
"FQN",
"Entity",
"EntityName",
# "Schema",
# "Definition",
"Description",
"Schema",
"Definition",
"Columns",
"EntityRelationships",
"CompleteEntityRelationshipsGraph",
] + engine_specific_fields

schemas = await self.run_ai_search_query(
text,
# ["DefinitionEmbedding"],
["DescriptionEmbedding"],
["DefinitionEmbedding"],
retrieval_fields,
os.environ["AIService__AzureSearchOptions__Text2SqlSchemaStore__Index"],
os.environ[
Expand All @@ -227,7 +225,7 @@ async def get_entity_schemas(
for schema in schemas:
filtered_schemas = []

# del schema["FQN"]
del schema["FQN"]

if (
schema["CompleteEntityRelationshipsGraph"] is not None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,12 @@ async def get_entity_schemas(
)

for schema in schemas:
# schema["SelectFromEntity"] = ".".join(
# [schema["Catalog"], schema["Schema"], schema["Entity"]]
# )
schema["SelectFromEntity"] = schema["Entity"]
schema["SelectFromEntity"] = ".".join(
[schema["Catalog"], schema["Schema"], schema["Entity"]]
)

del schema["Entity"]
# del schema["Schema"]
del schema["Schema"]
del schema["Catalog"]

if as_json:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ system_message:
- Use the current date/time above as reference point
- Replace relative dates like 'last month', 'this year', 'previous quarter' with absolute dates
- Maintain consistency in date formats (YYYY-MM-DD)
Examples of date resolution (assuming current date is {{ current_datetime }}):
- 'last month' -> specific month name and year
- 'this year' -> {{ current_datetime.year }}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ system_message:
<unsuccessful_mapping_entry>
- If you cannot map it to a column, add en entry to the disambiguation list with the clarification question you need from the user:
- If there are multiple possible options, or you are unsure how it maps, make sure to ask a clarification question.
- If there are no possible options, ask a clarification question for more detail.
{
\"disambiguation\": [
Expand Down

0 comments on commit 210ee09

Please sign in to comment.