diff --git a/search-app/server/app/server.py b/search-app/server/app/server.py index 9cc9291..2b663c5 100644 --- a/search-app/server/app/server.py +++ b/search-app/server/app/server.py @@ -1,32 +1,31 @@ from fastapi import FastAPI, HTTPException -import asyncio -from fastapi import FastAPI, HTTPException -import asyncio from fastapi.responses import RedirectResponse from langserve import add_routes from graph.graph import SpatialRetrieverGraph, State -from langchain_core.runnables import chain +from graph.routers import CollectionRouter from config.config import Config from indexing.indexer import Indexer from connectors.pygeoapi_retriever import PyGeoAPI from connectors.geojson_osm import GeoJSON -from langchain_core.runnables.graph import MermaidDrawMethod from langchain.schema import Document -from langchain_core.messages import HumanMessage, AIMessage +from langchain_core.messages import HumanMessage from fastapi.middleware.cors import CORSMiddleware -from .utils import SessionData, cookie, verifier, backend +from .utils import (SessionData, cookie, verifier, backend, + calculate_bounding_box, summarize_feature_collection_properties, + load_conversational_prompts) + from langgraph.checkpoint.aiosqlite import AsyncSqliteSaver from fastapi import HTTPException, FastAPI, Depends, Response, Security from fastapi.security.api_key import APIKeyHeader, APIKey from langgraph.checkpoint.aiosqlite import AsyncSqliteSaver from uuid import UUID, uuid4 -from typing import List, Optional +from typing import List import geojson from pydantic import BaseModel import json - import logging + logging.basicConfig() logging.getLogger().setLevel(logging.INFO) @@ -36,7 +35,9 @@ "http://localhost:5173", # Frontend app origin ] +# Init memory: memory = AsyncSqliteSaver.from_conn_string(":memory:") + ### Get session info via cookie async def get_current_session(session_id: UUID = Depends(cookie), session_data: SessionData = Depends(verifier)): return session_data @@ -54,40 +55,38 @@ 
async def get_current_session(session_id: UUID = Depends(cookie), session_data: # graph = SpatialRetrieverGraph(State(messages=[], search_criteria="", search_results=[], ready_to_retrieve="")).compile() graph = None session_id = None -""" -# Generate a visualization of the current dialog-module workflow -graph_visualization = graph.get_graph().draw_mermaid_png( - draw_method=MermaidDrawMethod.API, -) -with open("./graph/current_workflow.png", "wb") as f: - f.write(graph_visualization) -""" + # Create a dictionary of indexes indexes = { "pygeoapi": Indexer(index_name="pygeoapi", score_treshold= 0.4, k = 20), -} - -# Add indexer for local geojson with OSM features -geojson_osm_indexer = Indexer(index_name="geojson", + "geojson_osm_indexer": Indexer(index_name="geojson", # Add indexer for local geojson with OSM features score_treshold=-400.0, k = 20, use_hf_model=True, embedding_model="Alibaba-NLP/gte-large-en-v1.5" ) - +} # Add connection to local file including building features # Replace the value for tag_name argument if you have other data geojson_osm_connector = GeoJSON(tag_name="building") - """ +# We can also use a osm/geojson that comes from a web resource local_file_connector = GeoJSON(file_dir="https://webais.demo.52north.org/pygeoapi/collections/dresden_buildings/items", tag_name="building") """ +# Adding conversational routes. 
We do this here to avoid time-expensive llm calls during inference: +collection_router = CollectionRouter() + +# Check if already custom prompts generated and if yes: check if these match the existing search indexes +conversational_prompts = load_conversational_prompts(collection_router=collection_router) + + + app = FastAPI() app.add_middleware( @@ -127,8 +126,17 @@ async def create_session(response: Response): global graph - graph = SpatialRetrieverGraph(state=State(messages=[], search_criteria="", search_results=[], ready_to_retrieve=""), - thread_id=session_id, memory=memory).compile() + graph = SpatialRetrieverGraph(state=State(messages=[], + search_criteria="", + spatial_context="", + search_results=[], + ready_to_retrieve=""), + thread_id=session_id, + memory=memory, + search_indexes=indexes, + collection_router=collection_router, + conversational_prompts=conversational_prompts + ).compile() data = SessionData(session_id=session_id) @@ -137,23 +145,6 @@ async def create_session(response: Response): return {"message": f"created session for {session}"} -""" -@chain -async def call_graph(query: str, session_id: UUID = Depends(cookie), session_data: SessionData = Depends(verifier)): - if graph is not None: - print(f"-#-#--Running graph---- Using session_id: {str(session_id)}") - print(f"session_data: {session_data}") - inputs = {"messages": [HumanMessage(content=query)]} - graph.graph.thread_id = "test" - response = await graph.ainvoke(inputs) - else: - raise HTTPException(status_code=400, detail="No session created") - return response -""" -@app.get("/test_api_key") -async def test_api_key(api_key: APIKey = Depends(get_api_key)): - return f"Entered API KEY: {api_key}" - class Query(BaseModel): query: str @@ -192,7 +183,7 @@ async def index_geojson_osm(api_key: APIKey = Depends(get_api_key)): # await local_file_connector.add_descriptions_to_features() feature_docs = await geojson_osm_connector._features_to_docs() logging.info(f"Converted {len(feature_docs)} 
Features or FeatureGroups to documents") - res_local = geojson_osm_indexer._index(documents=feature_docs) + res_local = indexes['geojson_osm_indexer']._index(documents=feature_docs) return res_local def generate_combined_feature_collection(doc_list: List[Document]): @@ -208,15 +199,26 @@ def generate_combined_feature_collection(doc_list: List[Document]): features.extend(feature_list) combined_feature_collection = geojson.FeatureCollection(features) - geojson_str = geojson.dumps(combined_feature_collection, sort_keys=True, indent=2) + # geojson_str = geojson.dumps(combined_feature_collection, sort_keys=True, indent=2) - return geojson_str + return combined_feature_collection @app.get("/retrieve_geojson") async def retrieve_geojson(query: str): - features = geojson_osm_indexer.retriever.invoke(query) + features = indexes['geojson_osm_indexer'].retriever.invoke(query) + + feature_collection = generate_combined_feature_collection(features) + + spatial_extent = calculate_bounding_box(feature_collection) + properties = summarize_feature_collection_properties(feature_collection) + + summary = f"""Summary of found features: + {properties} + + Spatial Extent of all features: {spatial_extent} + """ - return generate_combined_feature_collection(features) + return feature_collection, summary @app.get("/clear_index") @@ -226,7 +228,7 @@ async def clear_index(index_name: str, api_key: APIKey = Depends(get_api_key)): if index_name == 'geojson': logging.info("Clearing geojson index") - geojson_osm_indexer._clear() + indexes['geojson_osm_indexer']._clear() else: logging.info(f"Clearing index: {index_name}") indexes[index_name]._clear() @@ -256,4 +258,4 @@ async def remove_doc_from_index(index_name: str, _id: str, api_key: APIKey = Dep if __name__ == "__main__": import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8000) + uvicorn.run(app, host="0.0.0.0", reload=False, port=8000) diff --git a/search-app/server/app/utils.py b/search-app/server/app/utils.py index 45d4169..6026ad3 
100644 --- a/search-app/server/app/utils.py +++ b/search-app/server/app/utils.py @@ -4,6 +4,17 @@ from fastapi_sessions.backends.implementations import InMemoryBackend from fastapi_sessions.session_verifier import SessionVerifier from fastapi_sessions.frontends.implementations import SessionCookie, CookieParameters +import os +import importlib +import json +import sys +from pathlib import Path +import logging + + +logging.basicConfig() +logging.getLogger().setLevel(logging.INFO) + class SessionData(BaseModel): session_id: str @@ -52,3 +63,98 @@ def verify_session(self, model: SessionData) -> bool: backend=backend, auth_http_exception=HTTPException(status_code=403, detail="Invalid session"), ) + +### Geojson utilities +def calculate_bounding_box(geojson): + min_lng, min_lat = float('inf'), float('inf') + max_lng, max_lat = float('-inf'), float('-inf') + + def extract_coordinates(geometry): + if geometry['type'] == 'Point': + return [geometry['coordinates']] + elif geometry['type'] in ['MultiPoint', 'LineString']: + return geometry['coordinates'] + elif geometry['type'] in ['MultiLineString', 'Polygon']: + return [coord for line in geometry['coordinates'] for coord in line] + elif geometry['type'] == 'MultiPolygon': + return [coord for poly in geometry['coordinates'] for line in poly for coord in line] + else: + return [] + + for feature in geojson['features']: + coords = extract_coordinates(feature['geometry']) + for coord in coords: + lng, lat = coord + min_lng = min(min_lng, lng) + min_lat = min(min_lat, lat) + max_lng = max(max_lng, lng) + max_lat = max(max_lat, lat) + + return [min_lng, min_lat, max_lng, max_lat] + + +def summarize_feature_collection_properties(feature_collection): + + data = list(map(lambda f: f['properties'], feature_collection['features'])) + + summary = {} + + for item in data: + item_type = item.get('type', '') + description = item.get('description', '') + + if item_type not in summary: + summary[item_type] = {'count': 0, 'descriptions': 
[]} + + summary[item_type]['count'] += 1 + + if description and description not in summary[item_type]['descriptions']: + summary[item_type]['descriptions'].append(description) + + summary_text = "" + for item_type, details in summary.items(): + summary_text += f"Type: {item_type} (Count: {details['count']})\nDescriptions:\n" + for desc in details['descriptions']: + summary_text += f"- {desc}\n" + summary_text += "\n" + + return summary_text.strip() + + +### Custom prompt utilities + +def save_conversational_prompts(file_name, conversational_prompts): + with open(file_name, 'w') as f: + json.dump(conversational_prompts, f, indent=4) # Pretty print with indentation + + +def read_dict_from_module(module_path): + module_name = Path(module_path).stem + if os.path.exists(f"{module_path}"): + try: + from graph.custom_prompts.custom_prompts import prompts + return prompts + except ImportError: + return None + else: + logging.info(f"Module '{module_name}.py' does not exist.") + return None + +def write_dict_to_file(dictionary, filename): + with open(filename, 'w') as file: + file.write(f"prompts = {repr(dictionary)}\n") + +def load_conversational_prompts(collection_router): + loaded_dict = read_dict_from_module('./graph/custom_prompts/custom_prompts.py') + collection_names = [c['collection_name'] for c in collection_router.coll_dicts] + + if loaded_dict and set(loaded_dict.keys()) == set(collection_names): + logging.info("Custom prompts already generated for current collections. 
Reading it from file...") + conversational_prompts = loaded_dict + else: + conversational_prompts = collection_router.generate_conversation_prompts() + write_dict_to_file(conversational_prompts,'./graph/custom_prompts/custom_prompts.py') + + return conversational_prompts + + diff --git a/search-app/server/connectors/__init__.py b/search-app/server/connectors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/search-app/server/connectors/geojson_osm.py b/search-app/server/connectors/geojson_osm.py index 4f9d7eb..dbd5a41 100644 --- a/search-app/server/connectors/geojson_osm.py +++ b/search-app/server/connectors/geojson_osm.py @@ -59,15 +59,19 @@ class GeoJSON(): _features_to_docs() -> List[Document]: Converts features into a list of Document objects for further use. """ - def __init__(self, file_dir: str = None, tag_name: str = "building"): + def __init__(self, file_dir: str = None, tag_name: str = None): if file_dir and is_url(file_dir): """We assume the online resource to be a collection published via a PyGeoAPI instance""" logging.info("Getting features from online resource") params = {"f": "json", "limit": 10000} gj = self._fetch_features_from_online_resource(file_dir, params) print(f"Retrieved {len(gj)} features") - - self.features = self._filter_meaningful_features(gj, tag_name) + + self.tag_name = tag_name + if self.tag_name: + self.features = self._filter_meaningful_features(gj, self.tag_name) + else: + self.features = gj else: if not file_dir: file_dir = config.local_geojson_files @@ -180,7 +184,8 @@ def _get_feature_description(self, feature): return "\n".join(description_parts) async def _features_to_docs(self) -> List[Document]: - await self.add_descriptions_to_features() + if self.tag_name: + await self.add_descriptions_to_features() # Part 1: Create documents for features with names features_with_names = list(filter(lambda feature: feature if feature["properties"].get("name", "") else None, self.features)) diff --git 
a/search-app/server/graph/actions.py b/search-app/server/graph/actions.py index 614d4d7..d981a80 100644 --- a/search-app/server/graph/actions.py +++ b/search-app/server/graph/actions.py @@ -2,57 +2,127 @@ generate_conversation_prompt, generate_final_answer_prompt ) + +from .spatial_utilities import ( + generate_spatial_context_chain +) from langchain_openai import ChatOpenAI from config.config import Config from langchain_openai import OpenAI -from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -from langchain_core.runnables.history import RunnableWithMessageHistory - -from langchain_core.pydantic_v1 import BaseModel, Field -from langchain.output_parsers.json import SimpleJsonOutputParser from langchain_core.output_parsers import StrOutputParser import json +from langchain.tools import tool +from typing import Literal +import requests +import logging + +logging.basicConfig() +logging.getLogger().setLevel(logging.INFO) + config = Config('./config/config.json') OPENAI_API_KEY = config.openai_api_key TAVILY_API_KEY = config.tavily_api_key -llm = ChatOpenAI(model="gpt-3.5-turbo-0125", +llm_with_structured_output = ChatOpenAI(model="gpt-3.5-turbo-0125", model_kwargs={ "response_format": { "type": "json_object" } }) -final_answer_llm = OpenAI(temperature=0) +llm_unstructured, final_answer_llm = OpenAI(temperature=0), OpenAI(temperature=0) -# Chains -conversation_chain = generate_conversation_prompt()| llm +def is_valid_json(myjson): + try: + json_object = json.loads(myjson) + except ValueError as e: + return False + return True -""" -chain_with_message_history = RunnableWithMessageHistory( - conversation_chain, - lambda session_id: memory, - input_messages_key="input", - history_messages_key="chat_history", -) +def run_converstation_chain(input: str, chat_history, prompt=None): + # Chains + conversation_chain = generate_conversation_prompt(system_prompt=prompt)| llm_with_structured_output -# Example usage -def run_converstation_chain(input: str): - 
history = chain_with_message_history.invoke( - {"input": input}, - {"configurable": {"session_id": "unused"}}, - ) - parsed_dict = json.loads(history.content) + logging.info(f"input to converation chain: {input}") - return history, parsed_dict -""" -def run_converstation_chain(input: str, chat_history): history = conversation_chain.invoke( {"input": input, "chat_history": chat_history} ) - parsed_dict = json.loads(history.content) + if history.content and is_valid_json(history.content): + parsed_dict = json.loads(history.content) + else: parsed_dict = {} + return history, parsed_dict final_answer_chain = ( generate_final_answer_prompt() | final_answer_llm | StrOutputParser() -) \ No newline at end of file +) + + +###### RETRIEVAL +## Search tool (dummy here. to be replaced) +@tool("search_tool") +def search_tool( + query_string: str, + index_name: str): + "Takes a query_string and a index_name as input and searches for data" + if index_name == "geojson": + response = requests.get(f"http://localhost:8000/retrieve_geojson?query={query_string}") + + else: + url = f"http://localhost:8000/retrieve_{index_name}/invoke" + json = {"input": query_string} + response = requests.post(url=url, json=json) + "Takes a search_dict as input and searches for data" + if response.status_code == 200: + docs = response.json() + return docs + + +### Custom search tools factory +def generate_search_tool(coll_dict): + collection_name = coll_dict['collection_name'] + collection_description = coll_dict.get('description', '') + + if collection_description: + docstring = f"Finds information in following collection: {collection_description}" + else: + docstring = f"Finds information in following collection: {collection_name}" + + + @tool(f"search_{collection_name}") + def search_tool(query: str, + search_index, + search_type: Literal["similarity", + "mmr", + "similarity_score_threshold"]="similarity", + score_treshold: float=0.5, + k: int=20): + """""" + + search_kwargs={"score_threshold": 
score_treshold, + "k": k} + + if search_type == "similarity": + retriever = search_index.vectorstore.as_retriever(search_kwargs={"k": k}) + else: + retriever = search_index.vectorstore.as_retriever(search_type=search_type, + search_kwargs=search_kwargs) + + docs = retriever.invoke(query) + + return docs + + search_tool.__doc__ = docstring + search_tool.func.__name__ = f"search_{collection_name}" + + return search_tool + + +### spatial context extraction +@tool("spatial_context_extraction_tool") +def spatial_context_extraction_tool(query: str): + """This tool extracts the spatial entities, scale and extent of a query""" + chain = generate_spatial_context_chain(llm=llm_unstructured) + + return chain.invoke({"query": query}) \ No newline at end of file diff --git a/search-app/server/graph/custom_prompts/__init__.py b/search-app/server/graph/custom_prompts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/search-app/server/graph/custom_prompts/custom_prompts.py b/search-app/server/graph/custom_prompts/custom_prompts.py new file mode 100644 index 0000000..5ba959c --- /dev/null +++ b/search-app/server/graph/custom_prompts/custom_prompts.py @@ -0,0 +1,2 @@ +prompts = {'geojson': '**AI Instructions:**\nYou are an AI designed to assist users in finding environmental or geospatial datasets. Follow these guidelines:\n1. **Extract Search Criteria:**\n2. **Refine the Search:** If the request is vague, ask follow-up questions about buildings or building types. Only re-ask a maximum of 3 times per inquiry and try to ask as few questions as possible. Use bold formatting (markdown) to highlight important aspects in your response.\n3. **Contextual Responses:** Keep track of the conversation context to use previous responses in refining the search.\n4. 
**Determine Readiness for Search:**\n - **Flag as Ready:** As soon as you have enough details to perform a meaningful search or if the user implies they want to proceed with the search, set the flag `"ready_to_retrieve": "yes"`.\n - **Avoid Over-Questioning:** If you sense the user is ready to search based on their input (e.g., "Sure, search for...", "That should be enough...", "Go ahead and find the data..."), immediately set the flag `"ready_to_retrieve": "yes"` and stop asking further questions.\n5. **Generate Search Query:** Once enough details are gathered, create a search string that combines all specified criteria.\n\n**Output Requirements:**\n - Always output a JSON object with an `"answer"` key (containing your response) and a `"search_criteria"` key (containing the extracted criteria).\n - If the search is ready to proceed, include `"ready_to_retrieve": "yes"` in the JSON object.\n\n**Tips for Natural Interaction:**\n- Maintain a friendly and conversational tone.\n- Acknowledge user inputs and express appreciation for their responses.\n- Keep responses clear and straightforward while ensuring they meet the user\'s needs.\n\n**Example Conversations:**\n1. User: Can you help me find information about different types of buildings?\n AI: Sure! Could you provide more details on the type of buildings you are interested in? For example, residential, commercial, or industrial buildings?\n User: I\'m looking for information on residential buildings.\n AI: Great choice! Let\'s start by focusing on residential buildings. Do you have any specific preferences such as building levels or architectural styles?\n User: I\'m interested in residential buildings with multiple levels.\n AI: Understood! I will include the criteria for residential buildings with multiple levels in the search.\n\n2. User: I need data on commercial buildings in a specific area.\n AI: Of course! 
Can you provide more details about the specific area you are interested in for commercial buildings?\n User: I\'m looking for commercial buildings in Dresden.\n AI: Got it! We will focus on commercial buildings in the Dresden area for the search.\n\n3. User: I want to learn about different types of public buildings.\n AI: Sure thing! Could you specify the type of public buildings you are interested in? For example, government offices, schools, or healthcare facilities?\n User: I\'m interested in government office buildings.\n AI: Perfect! Let\'s narrow down the search to focus on government office buildings. Thank you for the clarification!', + 'pygeoapi': '\n**AI Instructions:**\nYou are an AI designed to assist users in finding environmental or geospatial datasets. Follow these guidelines: \n1. **Extract Search Criteria:**\n2. **Refine the Search:** If the request is vague, ask follow-up questions about the LTCE collection, weather data, or climate projections. Only re-ask a maximum of 3 times per inquiry and try to ask as few questions as possible. Use bold formatting (markdown) to highlight important aspects in your response. \n3. **Contextual Responses:** Keep track of the conversation context to use previous responses in refining the search. \n4. **Determine Readiness for Search:**\n - **Flag as Ready:** As soon as you have enough details to perform a meaningful search or if the user implies they want to proceed with the search, set the flag `"ready_to_retrieve": "yes"`.\n - **Avoid Over-Questioning:** If you sense the user is ready to search based on their input (e.g., "Sure, search for...", "That should be enough...", "Go ahead and find the data..."), immediately set the flag `"ready_to_retrieve": "yes"` and stop asking further questions.\n5. 
**Generate Search Query:** Once enough details are gathered, create a search string that combines all specified criteria.\n\n**Output Requirements:**\n - Always output a JSON object with an `"answer"` key (containing your response) and a `"search_criteria"` key (containing the extracted criteria).\n - If the search is ready to proceed, include `"ready_to_retrieve": "yes"` in the JSON object.\n\n**Tips for Natural Interaction:**\n- Maintain a friendly and conversational tone.\n- Acknowledge user inputs and express appreciation for their responses.\n- Keep responses clear and straightforward while ensuring they meet the user\'s needs.\n\n\n**Example Conversations:** \n1. User: Can you help me find daily climate records data?\n AI: Great! Could you specify if you are looking for daily climate records related to precipitation, temperature, or both?\n User: I am interested in daily extremes of temperature.\n AI: Understood! Let\'s focus on daily extremes of temperature. Any specific location or time period you are looking for?\n User: I am interested in historical data from the 1800s.\n AI: Perfect! I have noted that you are interested in historical daily extremes of temperature data from the 1800s. We are ready to proceed with the search. \n\n2. User: I need information on climate projections for Canada.\n AI: Sure! Are you looking for seasonal projections, monthly projections, or both?\n User: I am interested in both seasonal and monthly projections.\n AI: Got it! We will search for both seasonal and monthly climate projections for Canada. Any specific time frame you have in mind?\n User: I am interested in projections for the period from 2015 to 2100.\n AI: Noted! We will search for seasonal and monthly climate projections for Canada from 2015 to 2100. We are ready to proceed with the search.\n\n3. User: Can you find real-time hydrometric data for me?\n AI: Absolutely! 
Could you specify if you are looking for real-time water level data, flow data, or both?\n User: I am interested in real-time water level data.\n AI: Great choice! We will focus on real-time water level data. Any specific region or station you are interested in?\n User: I am interested in hydrometric stations across Canada.\n AI: Perfect! We will search for real-time water level data collected at hydrometric stations across Canada. We are ready to proceed with the search.'} diff --git a/search-app/server/graph/graph.py b/search-app/server/graph/graph.py index b2df17f..32a3585 100644 --- a/search-app/server/graph/graph.py +++ b/search-app/server/graph/graph.py @@ -10,26 +10,56 @@ from .actions import ( run_converstation_chain, final_answer_chain, + generate_search_tool, + spatial_context_extraction_tool ) +import logging +logging.basicConfig() +logging.getLogger().setLevel(logging.INFO) + +def is_valid_json(myjson): + try: + json_object = json.loads(myjson) + except ValueError as e: + return False + return True class State(TypedDict): messages: Annotated[Sequence[BaseMessage], operator.add] search_criteria: str + spatial_context: str search_results: List ready_to_retrieve: str + index_name: str class SpatialRetrieverGraph(StateGraph): - def __init__(self, state: State, thread_id: int, memory): + def __init__(self, + state: State, + thread_id: int, + memory, + search_indexes: dict, + collection_router, + conversational_prompts: dict=None + ): super().__init__(state) self.setup_graph() self.counter = 0 self.thread_id = thread_id self.memory = memory - + self.search_indexes = search_indexes + self.conversational_prompts = conversational_prompts + self.route_layer = collection_router.rl + self.search_indexes = search_indexes + self.collection_info_dict = collection_router.coll_dicts + + # This takes generates an individual search tool for all collections available + self.search_tools = {c['collection_name']: generate_search_tool(c) for c in self.collection_info_dict} + 
def setup_graph(self): self.add_node("conversation", self.run_conversation) - self.add_node("tavily_search", self.run_tavily_search) + self.add_node("extract_spatial_context", self.extract_spatial_context) + self.add_node("search", self.run_search) self.add_node("final_answer", self.final_answer) self.add_node("save_state", self.save_state) @@ -38,10 +68,12 @@ def setup_graph(self): self.should_continue, { "human": "save_state", - "tavily_search": "tavily_search" + "extract_spatial_context": "extract_spatial_context" } ) - self.add_edge("tavily_search", "final_answer") + + self.add_edge("extract_spatial_context", "search") + self.add_edge("search", "final_answer") self.add_edge("final_answer", "save_state") self.add_edge("save_state", END) @@ -56,33 +88,83 @@ async def run_conversation(self, state: State): else: print("---start conversation (no previous messages)") chat_history = [] + + # check if already search_criteria in state. if yes, use semantic router to choose prompt and correct search index + search_criteria = state.get("search_criteria", "") + if search_criteria: + route_choice = self.route_layer(search_criteria) + else: + route_choice = self.route_layer(state["messages"][-1].content) - response, parsed_dict = run_converstation_chain(input=state["messages"][-1].content, - chat_history=chat_history) - answer = json.loads(response.content).get("answer", "") + prompt = None + if route_choice.name: + logging.info(f"Chosen route: {route_choice.name}") + state["index_name"] = route_choice.name + prompt = self.conversational_prompts[route_choice.name] + else: + logging.info("No route chosen, routing to default") + + logging.info(f"Custom prompt:{prompt}") + response, parsed_dict = run_converstation_chain(input=state["messages"][-1].content, + chat_history=chat_history, + prompt=prompt) + + if response.content and is_valid_json(response.content): + answer = json.loads(response.content).get("answer", "") + else: + answer = "Sorry, I am only designed to help you 
with finding data. Please try again typing your request :)" state["messages"].append(AIMessage(content=answer)) state["search_criteria"] = parsed_dict.get("search_criteria", "") state["ready_to_retrieve"] = parsed_dict.get("ready_to_retrieve", "no") return state - def run_tavily_search(self, state: State): - print("---running a tavily search") - tavily_search = TavilySearchResults() - search_results = tavily_search.invoke(state["search_criteria"]) + def extract_spatial_context(self, state: State): + print("---extracting spatial context of search") + spatial_context = spatial_context_extraction_tool.invoke({"query": str(state['search_criteria'])}) + state['spatial_context'] = spatial_context + + logging.info(f"Extracted following spatial context: {spatial_context}") + return state + + def run_search(self, state: State): + print("---running a search") + logging.info(f"Search criteria used: {state['search_criteria']}") + index_name = state.get("index_name", "") + + search_index = self.search_indexes[index_name] + search_tool = self.search_tools[index_name] + + if index_name: + logging.info(f"Starting search in index: {index_name} using this tool: {search_tool.name}") + + search_results = search_tool.invoke({"query": str(state['search_criteria']), + "search_index": search_index, + "search_type": "similarity", + "k": 3}) + else: + tavily_search = TavilySearchResults() + search_results = tavily_search.invoke(state["search_criteria"]) + state["search_results"] = search_results + state["messages"].append(AIMessage(content=f"Search results: {search_results}")) return state def should_continue(self, state: State) -> str: if state.get("ready_to_retrieve") == "yes": - print("---routing to search") - return "tavily_search" + print("---routing to spatial context extractor, then to search") + return "extract_spatial_context" else: return "human" def final_answer(self, state: State) -> str: - query, context = state["search_criteria"], state["search_results"] + if 
state["index_name"] == "geojson": + logging.info(f"I found: {state['search_results'][-1]}") + context = state["search_results"][-1] + else: + context = state["search_results"] + query = state["search_criteria"] answer = final_answer_chain.invoke({"query": query, "context": context}).strip() diff --git a/search-app/server/graph/prompts.py b/search-app/server/graph/prompts.py index d5870af..1c80dd5 100644 --- a/search-app/server/graph/prompts.py +++ b/search-app/server/graph/prompts.py @@ -1,64 +1,77 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain.prompts import PromptTemplate +import logging -def generate_conversation_prompt(): - system_prompt = """ - **AI Instructions:**ยด - You are an AI designed to assist users in finding environmental or geospatial datasets. Follow these guidelines: +logging.basicConfig() +logging.getLogger().setLevel(logging.INFO) - 1. **Extract Search Criteria:** Identify the specific type of environmental or geospatial data the user is requesting. - 2. **Refine the Search:** If the request is vague, ask follow-up questions about the time period, geographic area, resolution, or format to gather more details. Only re-ask maximum of 3 times per inquery and try to ask as less as possible. Use bold formatting (markdown) to highlight important aspects in your response. - 3. **Contextual Responses:** Keep track of the conversation context to use previous responses in refining the search. - 4. **Generate Search Query:** Once enough details are gathered, create a search string that combines all specified criteria. - 'You must always output a JSON object with an "answer" key and a "search_criteria" key.' - If you have the impression that the user gives the go to search, do not ask follow-up questions and add a flag "ready_to_retrieve": "yes". 
+def generate_conversation_prompt(system_prompt=None): + if not system_prompt: + logging.info("Using default system prompt") + system_prompt = """ + **AI Instructions:** + You are an AI designed to assist users in finding environmental or geospatial datasets. Follow these guidelines: + 1. **Extract Search Criteria:** + 2. **Refine the Search:** If the request is vague, ask follow-up questions about . Only re-ask a maximum of 3 times per inquiry and try to ask as few questions as possible. Use bold formatting (markdown) to highlight important aspects in your response. + 3. **Contextual Responses:** Keep track of the conversation context to use previous responses in refining the search. + 4. **Determine Readiness for Search:** + - **Flag as Ready:** As soon as you have enough details to perform a meaningful search or if the user implies they want to proceed with the search, set the flag `"ready_to_retrieve": "yes"`. + - **Avoid Over-Questioning:** If you sense the user is ready to search based on their input (e.g., "Sure, search for...", "That should be enough...", "Go ahead and find the data..."), immediately set the flag `"ready_to_retrieve": "yes"` and stop asking further questions. + 5. **Generate Search Query:** Once enough details are gathered, create a search string that combines all specified criteria. - **Tips for Natural Interaction:** - - Maintain a friendly and conversational tone. - - Acknowledge user inputs and express appreciation for their responses. - - Keep responses clear and straightforward while ensuring they meet the user's needs. + **Output Requirements:** + - Always output a JSON object with an `"answer"` key (containing your response) and a `"search_criteria"` key (containing the extracted criteria). + - If the search is ready to proceed, include `"ready_to_retrieve": "yes"` in the JSON object. - **Example Conversations:** + **Tips for Natural Interaction:** + - Maintain a friendly and conversational tone. 
+ - Acknowledge user inputs and express appreciation for their responses. + - Keep responses clear and straightforward while ensuring they meet the user's needs. - **Example 1:** + **Example Conversations:** - **User:** "I need environmental/geospatial datasets." + **Example 1:** - **Assistant:** "Sure, I can help you find environmental or geospatial datasets. Could you please provide more details about what specific **type of data** you're looking for? For example, are you interested in climate data, land use data, topographic data, or something else?" + **User:** "I need environmental/geospatial datasets." - **Example 2:** + **Assistant:** "Sure, I can help you find environmental or geospatial datasets. Could you please provide more details about what specific **type of data** you're looking for? For example, are you interested in climate data, land use data, topographic data, or something else?" - **User:** "I need precipitation data." + **Example 2:** - **Assistant:** "Great, you're looking for precipitation data. To narrow down the search and provide you with the most relevant datasets, it would be helpful to know if you have any **specific requirements**. For instance, could you specify the **time period** or **geographic area of interest**? You might also want to indicate the resolution or format of the data if applicable." + **User:** "I need precipitation data." - **User:** "2020" + **Assistant:** "Great, you're looking for precipitation data. To narrow down the search and provide you with the most relevant datasets, it would be helpful to know if you have any **specific requirements**. For instance, could you specify the **time period** or **geographic area of interest**? You might also want to indicate the resolution or format of the data if applicable." - **Assistant:** "Thank you. You mentioned you're looking for **precipitation data** for the **year 2020**. 
Is there a **particular geographic area** you're interested in, such as a specific **country, region, or city**?" + **User:** "2020" - **User:** "Europe" + **Assistant:** "Thank you. You mentioned you're looking for **precipitation data** for the **year 2020**. Is there a **particular geographic area** you're interested in, such as a specific **country, region, or city**?" - **Assistant:** "Got it. You're looking for **precipitation data** for **Europe** in the **year 2020**. I will generate a search query based on this information." + **User:** "Europe" - **Generating Search Query:** "precipitation data Europe 2020" + **Assistant:** "Got it. You're looking for **precipitation data** for **Europe** in the **year 2020**. I will generate a search query based on this information." - **Example 3:** + **Generating Search Query:** "precipitation data Europe 2020" - **User:** "I need land use data." + **Example 3:** - **Assistant:** "Sure, you're looking for land use data. Could you please specify if there's a particular **region** or **time period** you're interested in? Additionally, do you need this data at a specific **resolution** or in a **particular format**?" + **User:** "I need land use data." - **User:** "North America, last decade" + **Assistant:** "Sure, you're looking for land use data. Could you please specify if there's a particular **region** or **time period** you're interested in? Additionally, do you need this data at a specific **resolution** or in a **particular format**?" - **Assistant:** "Understood, you need land use data for North America over the last decade. Is there any specific format or resolution you prefer?" + **User:** "North America, last decade" - **User:** "High resolution" + **Assistant:** "Understood, you need land use data for North America over the last decade. Is there any specific format or resolution you prefer?" - **Assistant:** "Thank you for providing more details. 
I'll generate a search query for high-resolution land use data for North America over the last decade." + **User:** "High resolution" - **Generating Search Query:** "high resolution land use data North America 2010-2020" - """ + **Assistant:** "Thank you for providing more details. I'll generate a search query for high-resolution land use data for North America over the last decade." + + **Generating Search Query:** "high resolution land use data North America 2010-2020" + """ + else: + logging.info("Using custom system prompt") + prompt = ChatPromptTemplate.from_messages( [ @@ -70,6 +83,10 @@ def generate_conversation_prompt(): ("human", "{input}"), ], ) + + # Explicitely setting the input variables here because sometimes it was hallucinating other input variables. + prompt.input_variables = ['chat_history', 'input'] + prompt.messages[0].prompt.input_variables=[] return prompt @@ -77,7 +94,8 @@ def generate_conversation_prompt(): def generate_final_answer_prompt(): final_answer_prompt = PromptTemplate( template=""" - You are an assistant for question-answering tasks. + You are an assistant for question-answering tasks related to data search. + The question wil be a query and the context either the found datasets or a summary of the recieved data. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. 
Use three sentences maximum and keep the answer concise Question: {query}
diff --git a/search-app/server/graph/routers.py b/search-app/server/graph/routers.py
new file mode 100644
index 0000000..125fa39
--- /dev/null
+++ b/search-app/server/graph/routers.py
@@ -0,0 +1,160 @@
+import os
+from langchain_openai import ChatOpenAI
+from langchain_core.output_parsers import StrOutputParser
+from langchain.prompts import PromptTemplate
+from config.config import Config
+import chromadb
+from chromadb.config import Settings
+import re
+from semantic_router import Route
+from semantic_router.layer import RouteLayer
+from semantic_router.encoders import OpenAIEncoder
+import logging
+
+logging.basicConfig()
+logging.getLogger().setLevel(logging.INFO)
+
+
+
+config = Config('./config/config.json')
+
+OPENAI_API_KEY = config.openai_api_key
+
+class CollectionRouter():
+    """
+    Builds one semantic-router route per collection of a persistent Chroma
+    store and generates collection-specific conversation prompts via an LLM.
+    """
+    def __init__(self, persist_dir: str="../server/chroma_db"):
+        self.llm = ChatOpenAI(model="gpt-3.5-turbo-0125")
+        self.encoder = OpenAIEncoder()
+
+        # coll_dicts has to exist before the routes are generated:
+        # generate_route() writes each description back into it.
+        self.coll_dicts = self.get_collection_info(persist_dir=persist_dir)
+        self.routes = [self.generate_route(collection_dict=coll) for coll in self.coll_dicts]
+
+        self.rl = RouteLayer(encoder=self.encoder, routes=self.routes)
+
+    def get_collection_info(self, persist_dir: str) -> list:
+        """
+        Fetching information about existing collections
+
+        Returns a list with one dict per collection (keys "collection_name",
+        "embedding_model" and "sample_docs"). The collection named
+        "langchain" is skipped.
+        """
+        client = chromadb.Client(Settings(is_persistent=True,
+                                          persist_directory=persist_dir,
+                                          ))
+
+        coll_dicts = []
+        for c in client.list_collections():
+            logging.info(f"Looking into collection {c.name}")
+            if c.name == "langchain":
+                continue
+            coll = client.get_collection(c.name)
+            coll_dicts.append({
+                "collection_name": c.name,
+                # NOTE(review): private Chroma attribute - confirm every embedding function exposes MODEL_NAME
+                "embedding_model": c._embedding_function.MODEL_NAME,
+                "sample_docs": coll.peek()["documents"],
+            })
+        return coll_dicts
+
+    def generate_route(self, collection_dict: dict):
+        """
+        Create a semantic-router Route for a single collection.
+
+        The LLM is asked for a numbered list of example queries (used as the
+        route's utterances) followed by a "Description:" paragraph. The
+        description is also written into the matching entry of self.coll_dicts.
+        """
+        prompt_collection_desc = PromptTemplate(
+            template="""You receive a collection from a vector database.
+            Based on the collection's name and a few sample documents, you get an idea of the collection's contents, including its theme, type of data, and any notable characteristics.
+            Now, create a numbered list of example queries that users might submit to the
+            vector store to retrieve relevant information from this collection (ignore location references and proper names that can occur in the samples and generate generic queries).
+            After the list, leave one blank line and then write "Description:" followed by a brief description (60 words maximum) of the collection's contents, including its theme, type of data, and any notable characteristics
+
+            Collection:{collection}""",
+            input_variables= ["collection"],
+        )
+
+        coll_chain = (
+            prompt_collection_desc
+            | self.llm
+            | StrOutputParser()
+        )
+
+        result = coll_chain.invoke({"collection": collection_dict})
+
+        # Split the answer into the numbered list and the description.
+        # partition() works regardless of the whitespace in front of the label.
+        list_text, _, description = result.partition('Description:')
+        description = description.strip()
+
+        # Remove only the leading enumeration ("1. "); numbers inside a query are kept
+        utterances = [re.sub(r'^\s*\d+\.\s*', '', line).strip() for line in list_text.splitlines()]
+
+        # Write the collection description into the coll_dict
+        for c in self.coll_dicts:
+            if c["collection_name"] == collection_dict["collection_name"]:
+                c["description"] = description
+
+        route = Route(
+            name=collection_dict['collection_name'],
+            description=description,
+            score_threshold=0.7,
+            utterances=[u for u in utterances if u],
+        )
+        return route
+
+    def generate_conversation_prompts(self):
+        """
+        Let the LLM write one conversation prompt per collection.
+
+        Returns a dict that maps each collection name to the generated prompt text.
+        """
+        prompt = PromptTemplate(
+            template="""
+            You recieve a collection from a vector database. According to the collection name and sample docs, write a prompt that can be used for an agent that shall assist users in finding data.
+            Ignore possible spatial references (like place names) in the sample docs and generate a generic prompt.
+            Use the following structure:
+            ```
+            **AI Instructions:**
+            You are an AI designed to assist users in finding environmental or geospatial datasets. Follow these guidelines:
+            1. **Extract Search Criteria:**
+            2. **Refine the Search:** If the request is vague, ask follow-up questions about . Only re-ask a maximum of 3 times per inquiry and try to ask as few questions as possible. Use bold formatting (markdown) to highlight important aspects in your response.
+            3. **Contextual Responses:** Keep track of the conversation context to use previous responses in refining the search.
+            4. **Determine Readiness for Search:**
+            - **Flag as Ready:** As soon as you have enough details to perform a meaningful search or if the user implies they want to proceed with the search, set the flag `"ready_to_retrieve": "yes"`.
+            - **Avoid Over-Questioning:** If you sense the user is ready to search based on their input (e.g., "Sure, search for...", "That should be enough...", "Go ahead and find the data..."), immediately set the flag `"ready_to_retrieve": "yes"` and stop asking further questions.
+            5. **Generate Search Query:** Once enough details are gathered, create a search string that combines all specified criteria.
+
+            **Output Requirements:**
+            - Always output a JSON object with an `"answer"` key (containing your response) and a `"search_criteria"` key (containing the extracted criteria).
+            - If the search is ready to proceed, include `"ready_to_retrieve": "yes"` in the JSON object.
+
+            **Tips for Natural Interaction:**
+            - Maintain a friendly and conversational tone.
+            - Acknowledge user inputs and express appreciation for their responses.
+            - Keep responses clear and straightforward while ensuring they meet the user's needs.
+
+
+            **Example Conversations:**
+            ```
+
+            Here is the list: {collection}""",
+
+            input_variables=["collection"],
+        )
+        chain = (
+            prompt
+            | self.llm
+            | StrOutputParser()
+        )
+
+        logging.info("Generating individual prompts for all collections")
+        prompts = {c['collection_name']: chain.invoke({"collection": c}) for c in self.coll_dicts}
+        return prompts
\ No newline at end of file
diff --git a/search-app/server/graph/spatial_utilities.py b/search-app/server/graph/spatial_utilities.py
new file mode 100644
index 0000000..1b7f742
--- /dev/null
+++ b/search-app/server/graph/spatial_utilities.py
@@ -0,0 +1,107 @@
+from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain_core.output_parsers import JsonOutputParser
+from langchain.prompts import PromptTemplate
+from langchain.tools import tool
+import aiohttp
+import asyncio
+import nest_asyncio
+
+# Allow nested asyncio.run calls
+nest_asyncio.apply()
+
+# Define your desired data structure.
+class SpatialEntity(BaseModel):
+    original_query: str = Field(description="Get original query as prompted by the user")
+    spatial: str = Field(description="Get the spatial entity. Can be a location or place or a region")
+    scale: str = Field(description="Get the spatial scale")
+
+# Set up a parser + inject instructions into the prompt template.
+spatial_context_prompt_parser = JsonOutputParser(pydantic_object=SpatialEntity)
+
+spatial_context_prompt = PromptTemplate(
+    template="""
+    You are an expert in geography and spatial data.
+    Your task is to extract from a query spatial entities such as city, country or region names.
+    Also determine the spatial scale ("Local", "City", "Regional", "National", "Continental", "Global") from the given query.
+
+    Output:{format_instructions}\n{query}\n""",
+    input_variables=["query"],
+    partial_variables={"format_instructions": spatial_context_prompt_parser.get_format_instructions()},
+)
+
+
+async def query_osm_async(query_dict: dict):
+    """
+    Geocode query_dict['spatial'] with the Photon API (photon.komoot.io).
+
+    Returns {"results": [...]} holding name, country, type and extent of every
+    returned feature, or {"error": ...} when the response status is not 200.
+    """
+    photon_url = "https://photon.komoot.io/api"
+    # Hand the search text over via `params` so that it gets URL-encoded;
+    # interpolating it into the URL breaks on characters such as '&' or '#'.
+    params = {"q": query_dict['spatial']}
+
+    async with aiohttp.ClientSession() as session:
+        async with session.get(photon_url, params=params) as response:
+            if response.status != 200:
+                return {"error": f"Failed to query Photon (HTTP status {response.status})"}
+
+            results = await response.json()
+            simplified_results = [
+                {
+                    "name": res["properties"].get("name"),
+                    "country": f"{res['properties'].get('country')}",
+                    "type": res["properties"].get("type"),
+                    "extent": res["properties"].get("extent")
+                }
+                for res in results.get("features", [])
+            ]
+            return {"results": simplified_results}
+
+
+@tool
+def search_with_osm_query(original_query: str, spatial: str, scale: str):
+    """
+    Use query and search in osm
+    """
+    # NOTE(review): @tool presumably derives the tool description from the docstring above, so its text is left untouched - confirm
+    query_dict = {'spatial': spatial, 'scale': scale}
+    # Synchronous entry point for the async lookup; nest_asyncio.apply() at
+    # module level is what permits the nested asyncio.run call.
+    results = asyncio.run(query_osm_async(query_dict))
+    return {"original_query": original_query, "scale": scale, "results": results}
+
+osm_picker_prompt = PromptTemplate(
+    template="""
+    You are an expert in geography and spatial data.
+    Your task is to pick from the results list the best matching candidate according to the query.
+    If the original query includes a country information, consider this in your selection.
+    Also consider the type. E.g. if user asks for a 'river' also pick the corresponding result
+
+    Also consider the scale: {scale}
+    Query: {original_query}
+    Results: {results}
+    Output:""",
+    input_variables=["original_query", "scale", "results"],
+)
+
+def generate_spatial_context_chain(llm):
+    """
+    Compose the chain: extract the spatial entity from the query, look it up
+    with search_with_osm_query and let the LLM pick the best candidate.
+    """
+    spatial_context_chain = (
+        spatial_context_prompt
+        | llm
+        | spatial_context_prompt_parser
+        | search_with_osm_query
+        | osm_picker_prompt
+        | llm
+    )
+    return spatial_context_chain
+
+
+# Example usage:
+# response = generate_spatial_context_chain(llm).invoke({"query": "I need climate data for Berlin"})
+# print(response)