From 210ee09dc92c35e9d9b8d065340a08e52ea54523 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Fri, 13 Dec 2024 13:51:00 +0000 Subject: [PATCH] Fix bad changes (#101) --- .../autogen_text_2_sql/autogen_text_2_sql.py | 26 +++++++++---------- .../custom_agents/sql_query_cache_agent.py | 14 ++++++---- .../text_2_sql_core/connectors/ai_search.py | 12 ++++----- .../connectors/databricks_sql.py | 9 +++---- .../prompts/query_rewrite_agent.yaml | 2 +- .../prompts/sql_disambiguation_agent.yaml | 1 + 6 files changed, 32 insertions(+), 32 deletions(-) diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py index 493e6ce..1047207 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/autogen_text_2_sql.py @@ -18,11 +18,12 @@ from autogen_agentchat.base import Response import json import os -import asyncio from datetime import datetime + class EmptyResponseUserProxyAgent(UserProxyAgent): """UserProxyAgent that automatically responds with empty messages.""" + def __init__(self, name): super().__init__(name=name) self._has_responded = False @@ -35,6 +36,7 @@ async def on_messages_stream(self, messages, sender=None, config=None): yield message yield Response(chat_message=message) + class AutoGenText2Sql: def __init__(self, engine_specific_rules: str, **kwargs: dict): self.use_query_cache = False @@ -65,32 +67,31 @@ def get_all_agents(self): """Get all agents for the complete flow.""" # Get current datetime for the Query Rewrite Agent current_datetime = datetime.now() - + QUERY_REWRITE_AGENT = LLMAgentCreator.create( - "query_rewrite_agent", - current_datetime=current_datetime + "query_rewrite_agent", current_datetime=current_datetime ) - + SQL_QUERY_GENERATION_AGENT = LLMAgentCreator.create( "sql_query_generation_agent", target_engine=self.target_engine, engine_specific_rules=self.engine_specific_rules, **self.kwargs, ) - + SQL_SCHEMA_SELECTION_AGENT = SqlSchemaSelectionAgent( target_engine=self.target_engine, engine_specific_rules=self.engine_specific_rules, **self.kwargs, ) - + SQL_QUERY_CORRECTION_AGENT = LLMAgentCreator.create( "sql_query_correction_agent", target_engine=self.target_engine, engine_specific_rules=self.engine_specific_rules, **self.kwargs, ) - + SQL_DISAMBIGUATION_AGENT = LLMAgentCreator.create( "sql_disambiguation_agent", target_engine=self.target_engine, @@ -101,11 +102,9 @@ def get_all_agents(self): QUESTION_DECOMPOSITION_AGENT = LLMAgentCreator.create( "question_decomposition_agent" ) - + # Auto-responding UserProxyAgent - USER_PROXY = EmptyResponseUserProxyAgent( - name="user_proxy" - ) + USER_PROXY = EmptyResponseUserProxyAgent(name="user_proxy") agents = [ USER_PROXY, @@ -114,7 +113,7 @@ def get_all_agents(self): SQL_SCHEMA_SELECTION_AGENT, SQL_QUERY_CORRECTION_AGENT, QUESTION_DECOMPOSITION_AGENT, - SQL_DISAMBIGUATION_AGENT + SQL_DISAMBIGUATION_AGENT, ] if self.use_query_cache: @@ -192,7 +191,6 @@ def agentic_flow(self): allow_repeated_speaker=False, model_client=LLMModelCreator.get_model("4o-mini"), termination_condition=self.termination_condition, - selector_func=self.selector, selector_func=self.unified_selector, ) return flow diff --git a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py index e0b4435..5e1b603 100644 --- a/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py +++ b/text_2_sql/autogen/src/autogen_text_2_sql/custom_agents/sql_query_cache_agent.py @@ -47,19 +47,21 @@ async def on_messages_stream( # Initialize results dictionary cached_results = { "cached_questions_and_schemas": [], - "contains_pre_run_results": False + "contains_pre_run_results": False, } # Process each question sequentially for question in user_questions: # Fetch the queries from the cache based on the question logging.info(f"Fetching queries from cache for question: {question}") - cached_query = await self.sql_connector.fetch_queries_from_cache(question) - + cached_query = await self.sql_connector.fetch_queries_from_cache( + question + ) + # If any question has pre-run results, set the flag if cached_query.get("contains_pre_run_results", False): cached_results["contains_pre_run_results"] = True - + # Add the cached results for this question if cached_query.get("cached_questions_and_schemas"): cached_results["cached_questions_and_schemas"].extend( @@ -75,7 +77,9 @@ async def on_messages_stream( except json.JSONDecodeError: # If not JSON array, process as single question logging.info(f"Processing single question: {last_response}") - cached_queries = await self.sql_connector.fetch_queries_from_cache(last_response) + cached_queries = await self.sql_connector.fetch_queries_from_cache( + last_response + ) yield Response( chat_message=TextMessage( content=json.dumps(cached_queries), source=self.name diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py index c632a59..7158df5 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/ai_search.py @@ -198,12 +198,11 @@ async def get_entity_schemas( logging.info("Search Text: %s", text) retrieval_fields = [ - # "FQN", + "FQN", "Entity", "EntityName", - # "Schema", - # "Definition", - "Description", + "Schema", + "Definition", "Columns", "EntityRelationships", "CompleteEntityRelationshipsGraph", @@ -211,8 +210,7 @@ async def get_entity_schemas( schemas = await self.run_ai_search_query( text, - # ["DefinitionEmbedding"], - ["DescriptionEmbedding"], + ["DefinitionEmbedding"], retrieval_fields, os.environ["AIService__AzureSearchOptions__Text2SqlSchemaStore__Index"], os.environ[ @@ -227,7 +225,7 @@ async def get_entity_schemas( for schema in schemas: filtered_schemas = [] - # del schema["FQN"] + del schema["FQN"] if ( schema["CompleteEntityRelationshipsGraph"] is not None diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py index 8afa76d..cca4cc5 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/connectors/databricks_sql.py @@ -98,13 +98,12 @@ async def get_entity_schemas( ) for schema in schemas: - # schema["SelectFromEntity"] = ".".join( - # [schema["Catalog"], schema["Schema"], schema["Entity"]] - # ) - schema["SelectFromEntity"] = schema["Entity"] + schema["SelectFromEntity"] = ".".join( + [schema["Catalog"], schema["Schema"], schema["Entity"]] + ) del schema["Entity"] - # del schema["Schema"] + del schema["Schema"] del schema["Catalog"] if as_json: diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/query_rewrite_agent.yaml b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/query_rewrite_agent.yaml index bb1f9d5..f4325f2 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/query_rewrite_agent.yaml +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/query_rewrite_agent.yaml @@ -14,7 +14,7 @@ system_message: - Use the current date/time above as reference point - Replace relative dates like 'last month', 'this year', 'previous quarter' with absolute dates - Maintain consistency in date formats (YYYY-MM-DD) - + Examples of date resolution (assuming current date is {{ current_datetime }}): - 'last month' -> specific month name and year - 'this year' -> {{ current_datetime.year }} diff --git a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_disambiguation_agent.yaml b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_disambiguation_agent.yaml index 7f0dea8..6617b26 100644 --- a/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_disambiguation_agent.yaml +++ b/text_2_sql/text_2_sql_core/src/text_2_sql_core/prompts/sql_disambiguation_agent.yaml @@ -48,6 +48,7 @@ system_message: - If you cannot map it to a column, add en entry to the disambiguation list with the clarification question you need from the user: - If there are multiple possible options, or you are unsure how it maps, make sure to ask a clarification question. + - If there are no possible options, ask a clarification question for more detail. { \"disambiguation\": [