"""
Configuration classes for AWS Bedrock retrieve and generate API
"""

from dataclasses import dataclass
from typing import Any, Literal, Optional

# Private aliases so the recurring annotation shapes are spelled only once.
_JsonObject = dict[str, Any]
_LatencyMode = Literal["standard", "optimized"]
_SearchType = Literal["HYBRID", "SEMANTIC"]


@dataclass
class TextInferenceConfig:
    """Sampling settings for text inference; every field is optional."""

    maxTokens: Optional[int] = None
    stopSequences: Optional[list[str]] = None
    temperature: Optional[float] = None
    topP: Optional[float] = None


@dataclass
class PerformanceConfig:
    """Latency mode selection."""

    latency: _LatencyMode


@dataclass
class PromptTemplate:
    """Wrapper around a text prompt template."""

    textPromptTemplate: str


@dataclass
class GuardrailConfig:
    """Guardrail identifier together with its version."""

    guardrailId: str
    guardrailVersion: str


# The inferenceConfig mapping used below is keyed by a string and holds
# TextInferenceConfig values.
_InferenceConfig = dict[str, TextInferenceConfig]


@dataclass
class GenerationConfig:
    """Optional settings for the generation step."""

    additionalModelRequestFields: Optional[_JsonObject] = None
    guardrailConfiguration: Optional[GuardrailConfig] = None
    inferenceConfig: Optional[_InferenceConfig] = None
    performanceConfig: Optional[PerformanceConfig] = None
    promptTemplate: Optional[PromptTemplate] = None


@dataclass
class VectorSearchConfig:
    """Optional vector search settings: filter, result count and search type."""

    filter: Optional[_JsonObject] = None
    numberOfResults: Optional[int] = None
    overrideSearchType: Optional[_SearchType] = None


@dataclass
class RetrievalConfig:
    """Retrieval settings; wraps the mandatory vector search configuration."""

    vectorSearchConfiguration: VectorSearchConfig


@dataclass
class OrchestrationConfig:
    """Optional settings for the orchestration step (no guardrail field)."""

    additionalModelRequestFields: Optional[_JsonObject] = None
    inferenceConfig: Optional[_InferenceConfig] = None
    performanceConfig: Optional[PerformanceConfig] = None
    promptTemplate: Optional[PromptTemplate] = None


@dataclass
class KnowledgeBaseConfig:
    """Knowledge base settings.

    The generation configuration, knowledge base id and model ARN are required
    (in that positional order); orchestration and retrieval settings are optional.
    """

    generationConfiguration: GenerationConfig
    knowledgeBaseId: str
    modelArn: str
    orchestrationConfiguration: Optional[OrchestrationConfig] = None
    retrievalConfiguration: Optional[RetrievalConfig] = None


@dataclass
class SessionConfig:
    """Session settings."""

    kmsKeyArn: Optional[str] = None
    # NOTE(review): the boto3 ``sessionConfiguration`` request shape appears to
    # carry only ``kmsKeyArn`` - confirm whether ``sessionId`` belongs here.
    sessionId: Optional[str] = None


@dataclass
class RetrieveAndGenerateConfiguration:
    """Retrieve and generate configuration.

    The use of knowledgeBaseConfiguration or externalSourcesConfiguration depends
    on the ``type`` value; only the knowledge base variant is declared here.
    """

    type: str = "KNOWLEDGE_BASE"
    knowledgeBaseConfiguration: Optional[KnowledgeBaseConfig] = None


@dataclass
class RetrieveAndGenerateConfig:
    """Top-level request: input, retrieve-and-generate settings, optional session data."""

    input: dict[str, str]
    retrieveAndGenerateConfiguration: RetrieveAndGenerateConfiguration
    sessionConfiguration: Optional[SessionConfig] = None
    sessionId: Optional[str] = None
class BedrockRetrieveAndGenerateTool(BuiltinTool):
    """Builtin tool that calls ``retrieve_and_generate`` on a Bedrock knowledge base.

    Tool parameters are translated into a request for the boto3
    ``bedrock-agent-runtime`` client. The generated text and its citations are
    returned as a JSON message; failures are returned as text messages.
    """

    # boto3 client, created on the first invocation and reused afterwards.
    bedrock_client: Any = None

    def _parse_json(self, raw: Any, expected_type: type) -> Any:
        """Decode ``raw`` as JSON and return it only if it is an ``expected_type``.

        A value that already is an ``expected_type`` is returned unchanged.
        Returns ``None`` for undecodable input or a decoded value of another type,
        so malformed optional settings are skipped instead of reaching boto3.
        """
        if isinstance(raw, expected_type):
            return raw
        try:
            parsed = json.loads(raw)
        except (TypeError, ValueError):  # json.JSONDecodeError subclasses ValueError
            return None
        return parsed if isinstance(parsed, expected_type) else None

    def _parse_metadata_filter(self, raw: Any) -> Optional[dict]:
        """Return the metadata filter as a dict, or ``None`` when absent or empty.

        Raises:
            ValueError: if ``raw`` is not valid JSON or does not decode to an object.
        """
        if not raw:
            return None
        if isinstance(raw, dict):
            return raw
        try:
            parsed = json.loads(raw)
        except (TypeError, json.JSONDecodeError) as e:
            raise ValueError("metadata_filter must be a valid JSON string") from e
        if not isinstance(parsed, dict):
            raise ValueError("metadata_filter must be a valid JSON object")
        # An empty object means "no filter"; it is not sent to the service.
        return parsed or None

    def _create_text_inference_config(
        self,
        max_tokens: Optional[int] = None,
        stop_sequences: Optional[str] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> Optional[dict]:
        """Create text inference configuration.

        Each setting is included when it is not ``None``, so explicit zero values
        (e.g. ``temperature=0``) are kept. ``stop_sequences`` is a JSON array
        string; undecodable or non-array input falls back to an empty list.

        Returns:
            The configuration dict, or ``None`` when no setting was supplied.
        """
        config: dict[str, Any] = {}
        if max_tokens is not None:
            # The tool manifest declares this as a generic number; boto3 wants an int.
            config["maxTokens"] = int(max_tokens)
        if stop_sequences:
            parsed = self._parse_json(stop_sequences, list)
            config["stopSequences"] = parsed if parsed is not None else []
        if temperature is not None:
            config["temperature"] = temperature
        if top_p is not None:
            config["topP"] = top_p
        return config or None

    def _create_guardrail_config(
        self,
        guardrail_id: Optional[str] = None,
        guardrail_version: Optional[str] = None,
    ) -> Optional[dict]:
        """Create guardrail configuration.

        Returns ``None`` unless both the id and the version are provided.
        """
        if guardrail_id and guardrail_version:
            return {
                "guardrailId": guardrail_id,
                "guardrailVersion": guardrail_version,
            }
        return None

    def _create_generation_config(
        self,
        additional_model_fields: Optional[str] = None,
        guardrail_config: Optional[dict] = None,
        text_inference_config: Optional[dict] = None,
        performance_mode: Optional[str] = None,
        prompt_template: Optional[str] = None,
    ) -> dict:
        """Create generation configuration.

        Only the settings that were supplied are included.
        ``additional_model_fields`` is a JSON object string and is skipped when it
        cannot be decoded to an object.
        """
        config: dict[str, Any] = {}

        if additional_model_fields:
            parsed_fields = self._parse_json(additional_model_fields, dict)
            if parsed_fields is not None:
                config["additionalModelRequestFields"] = parsed_fields

        if guardrail_config:
            config["guardrailConfiguration"] = guardrail_config

        if text_inference_config:
            config["inferenceConfig"] = {"textInferenceConfig": text_inference_config}

        if performance_mode:
            config["performanceConfig"] = {"latency": performance_mode}

        if prompt_template:
            config["promptTemplate"] = {"textPromptTemplate": prompt_template}

        return config

    def _create_orchestration_config(
        self,
        orchestration_additional_model_fields: Optional[str] = None,
        orchestration_text_inference_config: Optional[dict] = None,
        orchestration_performance_mode: Optional[str] = None,
        orchestration_prompt_template: Optional[str] = None,
    ) -> dict:
        """Create orchestration configuration.

        Same layout as the generation configuration minus the guardrail entry, so
        the generation builder is reused with no guardrail.
        """
        return self._create_generation_config(
            additional_model_fields=orchestration_additional_model_fields,
            guardrail_config=None,
            text_inference_config=orchestration_text_inference_config,
            performance_mode=orchestration_performance_mode,
            prompt_template=orchestration_prompt_template,
        )

    def _create_vector_search_config(
        self,
        number_of_results: int = 5,
        search_type: str = "SEMANTIC",
        metadata_filter: Optional[dict] = None,
    ) -> dict:
        """Create vector search configuration.

        The filter is only added when ``metadata_filter`` is not empty.
        """
        config: dict[str, Any] = {
            # The tool manifest declares this as a generic number; boto3 wants an int.
            "numberOfResults": int(number_of_results),
            "overrideSearchType": search_type,
        }

        if metadata_filter:
            config["filter"] = metadata_filter

        return config

    def _format_response(self, response: dict[str, Any]) -> dict[str, Any]:
        """Reduce a ``retrieve_and_generate`` response to output text plus citations.

        Every citation carries the generated text part and its references; a
        reference location is the S3 URI for ``S3`` locations and ``None`` otherwise.
        """
        citations = []
        for citation in response.get("citations", []):
            references = []
            for ref in citation.get("retrievedReferences", []):
                location = ref.get("location", {})
                uri = None
                if location.get("type") == "S3":
                    uri = location.get("s3Location", {}).get("uri")
                references.append(
                    {
                        "content": ref.get("content", {}).get("text", ""),
                        "metadata": ref.get("metadata", {}),
                        "location": uri,
                    }
                )
            generated_part = citation.get("generatedResponsePart", {}).get("textResponsePart", {})
            citations.append({"text": generated_part.get("text", ""), "references": references})

        return {
            "output": response.get("output", {}).get("text", ""),
            "citations": citations,
        }

    def _bedrock_retrieve_and_generate(
        self,
        query: str,
        knowledge_base_id: str,
        model_arn: str,
        # Generation Configuration
        additional_model_fields: Optional[str] = None,
        guardrail_id: Optional[str] = None,
        guardrail_version: Optional[str] = None,
        max_tokens: Optional[int] = None,
        stop_sequences: Optional[str] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        performance_mode: str = "standard",
        prompt_template: Optional[str] = None,
        # Orchestration Configuration
        orchestration_additional_model_fields: Optional[str] = None,
        orchestration_max_tokens: Optional[int] = None,
        orchestration_stop_sequences: Optional[str] = None,
        orchestration_temperature: Optional[float] = None,
        orchestration_top_p: Optional[float] = None,
        orchestration_performance_mode: Optional[str] = None,
        orchestration_prompt_template: Optional[str] = None,
        # Retrieval Configuration
        number_of_results: int = 5,
        search_type: str = "SEMANTIC",
        metadata_filter: Optional[dict] = None,
        # Additional Configuration
        session_id: Optional[str] = None,
    ) -> dict[str, Any]:
        """Build the request, call ``retrieve_and_generate`` and format the response.

        Returns:
            A dict with ``output`` (generated text) and ``citations``.

        Raises:
            Exception: wrapping any failure while building or sending the request;
                the original exception is chained as the cause.
        """
        try:
            text_inference_config = self._create_text_inference_config(
                max_tokens, stop_sequences, temperature, top_p
            )
            orchestration_text_inference_config = self._create_text_inference_config(
                orchestration_max_tokens,
                orchestration_stop_sequences,
                orchestration_temperature,
                orchestration_top_p,
            )
            guardrail_config = self._create_guardrail_config(guardrail_id, guardrail_version)

            knowledge_base_config = {
                "knowledgeBaseId": knowledge_base_id,
                "modelArn": model_arn,
                "generationConfiguration": self._create_generation_config(
                    additional_model_fields,
                    guardrail_config,
                    text_inference_config,
                    performance_mode,
                    prompt_template,
                ),
                "orchestrationConfiguration": self._create_orchestration_config(
                    orchestration_additional_model_fields,
                    orchestration_text_inference_config,
                    orchestration_performance_mode,
                    orchestration_prompt_template,
                ),
                "retrievalConfiguration": {
                    "vectorSearchConfiguration": self._create_vector_search_config(
                        number_of_results, search_type, metadata_filter
                    )
                },
            }

            request_config: dict[str, Any] = {
                "input": {"text": query},
                "retrieveAndGenerateConfiguration": {
                    "type": "KNOWLEDGE_BASE",
                    "knowledgeBaseConfiguration": knowledge_base_config,
                },
            }

            # The session id is a top-level request member. Per the boto3 reference,
            # ``sessionConfiguration`` only carries ``kmsKeyArn``, so it is not sent.
            if session_id and len(session_id) >= 2:
                request_config["sessionId"] = session_id

            response = self.bedrock_client.retrieve_and_generate(**request_config)
            return self._format_response(response)

        except Exception as e:
            raise Exception(f"Error calling Bedrock service: {str(e)}") from e

    def _invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any],
    ) -> ToolInvokeMessage:
        """Run the tool and return a JSON message, or a text message on failure."""
        try:
            # Report missing or malformed parameters before touching AWS.
            try:
                self.validate_parameters(tool_parameters)
            except ValueError as e:
                return self.create_text_message(str(e))

            # Initialize Bedrock client if not already initialized
            if not self.bedrock_client:
                aws_region = tool_parameters.get("aws_region")
                aws_access_key_id = tool_parameters.get("aws_access_key_id")
                aws_secret_access_key = tool_parameters.get("aws_secret_access_key")

                client_kwargs = {
                    "service_name": "bedrock-agent-runtime",
                }
                if aws_region:
                    client_kwargs["region_name"] = aws_region
                # Only add credentials if both access key and secret key are provided
                if aws_access_key_id and aws_secret_access_key:
                    client_kwargs.update(
                        {
                            "aws_access_key_id": aws_access_key_id,
                            "aws_secret_access_key": aws_secret_access_key,
                        }
                    )

                try:
                    self.bedrock_client = boto3.client(**client_kwargs)
                except Exception as e:
                    return self.create_text_message(f"Failed to initialize Bedrock client: {str(e)}")

            # Already validated above; yields None for a missing or empty filter.
            metadata_filter = self._parse_metadata_filter(tool_parameters.get("metadata_filter"))

            try:
                response = self._bedrock_retrieve_and_generate(
                    query=tool_parameters["query"],
                    knowledge_base_id=tool_parameters["knowledge_base_id"],
                    model_arn=tool_parameters["model_arn"],
                    # Generation Configuration
                    additional_model_fields=tool_parameters.get("additional_model_fields"),
                    guardrail_id=tool_parameters.get("guardrail_id"),
                    guardrail_version=tool_parameters.get("guardrail_version"),
                    max_tokens=tool_parameters.get("max_tokens"),
                    stop_sequences=tool_parameters.get("stop_sequences"),
                    temperature=tool_parameters.get("temperature"),
                    top_p=tool_parameters.get("top_p"),
                    performance_mode=tool_parameters.get("performance_mode", "standard"),
                    prompt_template=tool_parameters.get("prompt_template"),
                    # Orchestration Configuration
                    orchestration_additional_model_fields=tool_parameters.get(
                        "orchestration_additional_model_fields"
                    ),
                    orchestration_max_tokens=tool_parameters.get("orchestration_max_tokens"),
                    orchestration_stop_sequences=tool_parameters.get("orchestration_stop_sequences"),
                    orchestration_temperature=tool_parameters.get("orchestration_temperature"),
                    orchestration_top_p=tool_parameters.get("orchestration_top_p"),
                    orchestration_performance_mode=tool_parameters.get("orchestration_performance_mode"),
                    orchestration_prompt_template=tool_parameters.get("orchestration_prompt_template"),
                    # Retrieval Configuration
                    number_of_results=tool_parameters.get("number_of_results", 5),
                    search_type=tool_parameters.get("search_type", "SEMANTIC"),
                    metadata_filter=metadata_filter,
                    # Additional Configuration
                    session_id=tool_parameters.get("session_id"),
                )
                return self.create_json_message(response)

            except Exception as e:
                return self.create_text_message(f"Tool invocation error: {str(e)}")

        except Exception as e:
            return self.create_text_message(f"Tool execution error: {str(e)}")

    def validate_parameters(self, parameters: dict[str, Any]) -> None:
        """Validate the parameters.

        Raises:
            ValueError: if ``query``, ``model_arn`` or ``knowledge_base_id`` is
                missing or empty, or if ``metadata_filter`` is not a JSON object.
        """
        for param in ("query", "model_arn", "knowledge_base_id"):
            if not parameters.get(param):
                raise ValueError(f"{param} is required")

        # Raises ValueError itself when the filter is malformed.
        self._parse_metadata_filter(parameters.get("metadata_filter"))
human_description: + en_US: AWS access key ID for authentication (optional) + zh_Hans: 用于身份验证的AWS访问密钥ID(可选) + form: form + + - name: aws_secret_access_key + type: string + required: false + label: + en_US: AWS Secret Access Key + zh_Hans: AWS秘密访问密钥 + human_description: + en_US: AWS secret access key for authentication (optional) + zh_Hans: 用于身份验证的AWS秘密访问密钥(可选) + form: form + + # Knowledge Base Configuration + - name: knowledge_base_id + type: string + required: true + label: + en_US: Knowledge Base ID + zh_Hans: 知识库ID + human_description: + en_US: ID of the Bedrock Knowledge Base + zh_Hans: Bedrock知识库的ID + form: form + + - name: model_arn + type: string + required: true + label: + en_US: Model ARN + zh_Hans: 模型ARN + human_description: + en_US: The ARN of the model to use + zh_Hans: 要使用的模型ARN + form: form + + # Retrieval Configuration + - name: query + type: string + required: true + label: + en_US: Query + zh_Hans: 查询 + human_description: + en_US: The search query to retrieve information + zh_Hans: 用于检索信息的查询语句 + form: llm + + - name: number_of_results + type: number + required: false + label: + en_US: Number of Results + zh_Hans: 结果数量 + human_description: + en_US: Number of results to retrieve (1-10) + zh_Hans: 要检索的结果数量(1-10) + default: 5 + min: 1 + max: 10 + form: form + + - name: search_type + type: select + required: false + label: + en_US: Search Type + zh_Hans: 搜索类型 + human_description: + en_US: Type of search to perform + zh_Hans: 要执行的搜索类型 + default: SEMANTIC + options: + - value: SEMANTIC + label: + en_US: Semantic Search + zh_Hans: 语义搜索 + - value: HYBRID + label: + en_US: Hybrid Search + zh_Hans: 混合搜索 + form: form + + - name: metadata_filter + type: string + required: false + label: + en_US: Metadata Filter + zh_Hans: 元数据过滤器 + human_description: + en_US: JSON formatted filter conditions for metadata, supporting operations like equals, greaterThan, lessThan, etc. 
+ zh_Hans: 元数据的JSON格式过滤条件,支持等于、大于、小于等操作 + default: "{}" + form: form + +# Generation Configuration + - name: guardrail_id + type: string + required: false + label: + en_US: Guardrail ID + zh_Hans: 防护栏ID + human_description: + en_US: ID of the guardrail to apply + zh_Hans: 要应用的防护栏ID + form: form + + - name: guardrail_version + type: string + required: false + label: + en_US: Guardrail Version + zh_Hans: 防护栏版本 + human_description: + en_US: Version of the guardrail to apply + zh_Hans: 要应用的防护栏版本 + form: form + + - name: max_tokens + type: number + required: false + label: + en_US: Maximum Tokens + zh_Hans: 最大令牌数 + human_description: + en_US: Maximum number of tokens to generate + zh_Hans: 生成的最大令牌数 + default: 2048 + form: form + + - name: stop_sequences + type: string + required: false + label: + en_US: Stop Sequences + zh_Hans: 停止序列 + human_description: + en_US: JSON array of strings that will stop generation when encountered + zh_Hans: JSON数组格式的字符串,遇到这些序列时将停止生成 + default: "[]" + form: form + + - name: temperature + type: number + required: false + label: + en_US: Temperature + zh_Hans: 温度 + human_description: + en_US: Controls randomness in the output (0-1) + zh_Hans: 控制输出的随机性(0-1) + default: 0.7 + min: 0 + max: 1 + form: form + + - name: top_p + type: number + required: false + label: + en_US: Top P + zh_Hans: Top P值 + human_description: + en_US: Controls diversity via nucleus sampling (0-1) + zh_Hans: 通过核采样控制多样性(0-1) + default: 0.95 + min: 0 + max: 1 + form: form + + - name: performance_mode + type: select + required: false + label: + en_US: Performance Mode + zh_Hans: 性能模式 + human_description: + en_US: Select performance optimization mode(performanceConfig.latency) + zh_Hans: 选择性能优化模式(performanceConfig.latency) + default: standard + options: + - value: standard + label: + en_US: Standard + zh_Hans: 标准 + - value: optimized + label: + en_US: Optimized + zh_Hans: 优化 + form: form + + - name: prompt_template + type: string + required: false + label: + en_US: Prompt 
Template + zh_Hans: 提示模板 + human_description: + en_US: Custom prompt template for generation + zh_Hans: 用于生成的自定义提示模板 + form: form + + - name: additional_model_fields + type: string + required: false + label: + en_US: Additional Model Fields + zh_Hans: 额外模型字段 + human_description: + en_US: JSON formatted additional fields for model configuration + zh_Hans: JSON格式的额外模型配置字段 + default: "{}" + form: form + +# Orchestration Configuration + - name: orchestration_max_tokens + type: number + required: false + label: + en_US: Orchestration Maximum Tokens + zh_Hans: 编排最大令牌数 + human_description: + en_US: Maximum number of tokens for orchestration + zh_Hans: 编排过程的最大令牌数 + default: 2048 + form: form + + - name: orchestration_stop_sequences + type: string + required: false + label: + en_US: Orchestration Stop Sequences + zh_Hans: 编排停止序列 + human_description: + en_US: JSON array of strings that will stop orchestration when encountered + zh_Hans: JSON数组格式的字符串,遇到这些序列时将停止编排 + default: "[]" + form: form + + - name: orchestration_temperature + type: number + required: false + label: + en_US: Orchestration Temperature + zh_Hans: 编排温度 + human_description: + en_US: Controls randomness in the orchestration output (0-1) + zh_Hans: 控制编排输出的随机性(0-1) + default: 0.7 + min: 0 + max: 1 + form: form + + - name: orchestration_top_p + type: number + required: false + label: + en_US: Orchestration Top P + zh_Hans: 编排Top P值 + human_description: + en_US: Controls diversity via nucleus sampling in orchestration (0-1) + zh_Hans: 通过核采样控制编排的多样性(0-1) + default: 0.95 + min: 0 + max: 1 + form: form + + - name: orchestration_performance_mode + type: select + required: false + label: + en_US: Orchestration Performance Mode + zh_Hans: 编排性能模式 + human_description: + en_US: Select performance optimization mode for orchestration + zh_Hans: 选择编排的性能优化模式 + default: standard + options: + - value: standard + label: + en_US: Standard + zh_Hans: 标准 + - value: optimized + label: + en_US: Optimized + zh_Hans: 优化 + form: form + + 
- name: orchestration_prompt_template + type: string + required: false + label: + en_US: Orchestration Prompt Template + zh_Hans: 编排提示模板 + human_description: + en_US: Custom prompt template for orchestration + zh_Hans: 用于编排的自定义提示模板 + form: form + + - name: orchestration_additional_model_fields + type: string + required: false + label: + en_US: Orchestration Additional Model Fields + zh_Hans: 编排额外模型字段 + human_description: + en_US: JSON formatted additional fields for orchestration model configuration + zh_Hans: JSON格式的编排模型额外配置字段 + default: "{}" + form: form \ No newline at end of file diff --git a/api/pyproject.toml b/api/pyproject.toml index 12455a0e63678d..8c4d5fd283f880 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -21,7 +21,7 @@ azure-ai-inference = "~1.0.0b3" azure-ai-ml = "~1.20.0" azure-identity = "1.16.1" beautifulsoup4 = "4.12.2" -boto3 = "1.35.74" +boto3 = "1.36.4" bs4 = "~0.0.1" cachetools = "~5.3.0" celery = "~5.4.0"