From 42adc2ab7ee2e32e5898231b26d701c0aa0bc2b5 Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Tue, 10 Sep 2024 13:56:16 +0100 Subject: [PATCH] Update the deployment script --- ai_search_with_adi/ai_search/ai_search.py | 88 +++++++++++-------- ai_search_with_adi/ai_search/deploy.py | 43 +++------ ai_search_with_adi/ai_search/rag_documents.py | 28 +++--- 3 files changed, 78 insertions(+), 81 deletions(-) diff --git a/ai_search_with_adi/ai_search/ai_search.py b/ai_search_with_adi/ai_search/ai_search.py index d2c9ba5..5527a31 100644 --- a/ai_search_with_adi/ai_search/ai_search.py +++ b/ai_search_with_adi/ai_search/ai_search.py @@ -20,7 +20,6 @@ SearchIndexerDataContainer, SearchIndexerDataSourceConnection, SearchIndexerDataSourceType, - SearchIndexerDataUserAssignedIdentity, OutputFieldMappingEntry, InputFieldMappingEntry, SynonymMap, @@ -29,29 +28,21 @@ ) from azure.core.exceptions import HttpResponseError from azure.search.documents.indexes import SearchIndexerClient, SearchIndexClient -from ai_search_with_adi.ai_search.environment import ( - get_fq_blob_connection_string, - get_blob_container_name, - get_custom_skill_function_url, - get_managed_identity_fqname, - get_function_app_authresourceid, -) +from ai_search_with_adi.ai_search.environment import AISearchEnvironment, IdentityType class AISearch(ABC): + """Handles the deployment of the AI search pipeline.""" + def __init__( self, - endpoint: str, - credential, suffix: str | None = None, rebuild: bool | None = False, ): """Initialize the AI search class Args: - endpoint (str): The search endpoint - credential (AzureKeyCredential): The search credential - suffix (str, optional): The suffix for the indexer. Defaults to None. + suffix (str, optional): The suffix for the indexer. Defaults to None. If an suffix is provided, it is assumed to be a test indexer. rebuild (bool, optional): Whether to rebuild the index. Defaults to False. """ self.indexer_type = None @@ -61,6 +52,7 @@ def __init__( else: self.rebuild = False + # If suffix is None, then it is not a test indexer. Test indexer limits the rate of indexing and turns off the schedule. Useful for testing index changes if suffix is None: self.suffix = "" self.test = False @@ -68,8 +60,14 @@ def __init__( self.suffix = f"-{suffix}-test" self.test = True - self._search_indexer_client = SearchIndexerClient(endpoint, credential) - self._search_index_client = SearchIndexClient(endpoint, credential) + self.environment = AISearchEnvironment(indexer_type=self.indexer_type) + + self._search_indexer_client = SearchIndexerClient( + self.environment.ai_search_endpoint, self.environment.ai_search_credential + ) + self._search_index_client = SearchIndexClient( + self.environment.ai_search_endpoint, self.environment.ai_search_credential + ) @property def indexer_name(self): @@ -94,7 +92,7 @@ def index_name(self): @property def data_source_name(self): """Get the data source name for the indexer.""" - blob_container_name = get_blob_container_name(self.indexer_type) + blob_container_name = self.environment.get_blob_container_name() return f"{blob_container_name}-data-source{self.suffix}" @property @@ -146,16 +144,6 @@ def get_synonym_map_names(self) -> list[str]: """Get the synonym map names for the indexer.""" return [] - def get_user_assigned_managed_identity( - self, - ) -> SearchIndexerDataUserAssignedIdentity: - """Get user assigned managed identity details""" - - user_assigned_identity = SearchIndexerDataUserAssignedIdentity( - user_assigned_identity=get_managed_identity_fqname() - ) - return user_assigned_identity - def get_data_source(self) -> SearchIndexerDataSourceConnection: """Get the data source for the indexer.""" @@ -166,19 +154,21 @@ def get_data_source(self) -> SearchIndexerDataSourceConnection: ) container = SearchIndexerDataContainer( - name=get_blob_container_name(self.indexer_type) + name=self.environment.get_blob_container_name() ) data_source_connection = SearchIndexerDataSourceConnection( name=self.data_source_name, type=SearchIndexerDataSourceType.AZURE_BLOB, - connection_string=get_fq_blob_connection_string(), + connection_string=self.environment.blob_connection_string, container=container, data_change_detection_policy=data_change_detection_policy, data_deletion_detection_policy=data_deletion_detection_policy, - identity=self.get_user_assigned_managed_identity(), ) + if self.environment.identity_type != IdentityType.KEY: + data_source_connection.identity = self.environment.ai_search_identity_id + return data_source_connection def get_pre_embedding_cleaner_skill( @@ -226,17 +216,25 @@ def get_pre_embedding_cleaner_skill( name="Pre Embedding Cleaner Skill", description="Skill to clean the data before sending to embedding", context=context, - uri=get_custom_skill_function_url("pre_embedding_cleaner"), + uri=self.environment.get_custom_skill_function_url("pre_embedding_cleaner"), timeout="PT230S", batch_size=batch_size, degree_of_parallelism=degree_of_parallelism, http_method="POST", inputs=pre_embedding_cleaner_skill_inputs, outputs=pre_embedding_cleaner_skill_outputs, - auth_resource_id=get_function_app_authresourceid(), - auth_identity=self.get_user_assigned_managed_identity(), ) + if self.environment.identity_type != IdentityType.KEY: + pre_embedding_cleaner_skill.auth_identity = ( + self.environment.ai_search_identity_id + ) + + if self.environment.identity_type == IdentityType.USER_ASSIGNED: + pre_embedding_cleaner_skill.auth_resource_id = ( + self.environment.ai_search_user_assigned_identity + ) + return pre_embedding_cleaner_skill def get_text_split_skill(self, context, source) -> SplitSkill: @@ -294,7 +292,7 @@ def get_adi_skill(self, chunk_by_page=False) -> WebApiSkill: name="ADI Skill", description="Skill to generate ADI", context="/document", - uri=get_custom_skill_function_url("adi"), + uri=self.environment.get_custom_skill_function_url("adi"), timeout="PT230S", batch_size=batch_size, degree_of_parallelism=degree_of_parallelism, @@ -306,10 +304,16 @@ def get_adi_skill(self, chunk_by_page=False) -> WebApiSkill: ) ], outputs=output, - auth_resource_id=get_function_app_authresourceid(), - auth_identity=self.get_user_assigned_managed_identity(), ) + if self.environment.identity_type != IdentityType.KEY: + adi_skill.auth_identity = self.environment.ai_search_identity_id + + if self.environment.identity_type == IdentityType.USER_ASSIGNED: + adi_skill.auth_resource_id = ( + self.environment.ai_search_user_assigned_identity + ) + return adi_skill def get_vector_skill( @@ -368,17 +372,25 @@ def get_key_phrase_extraction_skill(self, context, source) -> WebApiSkill: name="Key phrase extraction API", description="Skill to extract keyphrases", context=context, - uri=get_custom_skill_function_url("keyphraseextraction"), + uri=self.environment.get_custom_skill_function_url("key_phrase_extraction"), timeout="PT230S", batch_size=batch_size, degree_of_parallelism=degree_of_parallelism, http_method="POST", inputs=keyphrase_extraction_skill_inputs, outputs=keyphrase_extraction__skill_outputs, - auth_resource_id=get_function_app_authresourceid(), - auth_identity=self.get_user_assigned_managed_identity(), ) + if self.environment.identity_type != IdentityType.KEY: + key_phrase_extraction_skill.auth_identity = ( + self.environment.ai_search_identity_id + ) + + if self.environment.identity_type == IdentityType.USER_ASSIGNED: + key_phrase_extraction_skill.auth_resource_id = ( + self.environment.ai_search_user_assigned_identity + ) + return key_phrase_extraction_skill def get_vector_search(self) -> VectorSearch: diff --git a/ai_search_with_adi/ai_search/deploy.py b/ai_search_with_adi/ai_search/deploy.py index 5e1ffb2..e28a61c 100644 --- a/ai_search_with_adi/ai_search/deploy.py +++ b/ai_search_with_adi/ai_search/deploy.py @@ -1,49 +1,26 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - import argparse -from ai_search_with_adi.ai_search.environment import ( - get_search_endpoint, - get_managed_identity_id, - get_search_key, - get_key_vault_url, -) -from azure.core.credentials import AzureKeyCredential -from azure.identity import DefaultAzureCredential -from azure.keyvault.secrets import SecretClient from ai_search_with_adi.ai_search.rag_documents import RagDocumentsAISearch -def main(args): - endpoint = get_search_endpoint() - - try: - credential = DefaultAzureCredential( - managed_identity_client_id=get_managed_identity_id() - ) - # initializing key vault client - client = SecretClient(vault_url=get_key_vault_url(), credential=credential) - print("Using managed identity credential") - except Exception as e: - print(e) - credential = AzureKeyCredential(get_search_key(client=client)) - print("Using Azure Key credential") +def deploy_config(arguments: argparse.Namespace): + """Deploy the indexer configuration based on the arguments passed. - if args.indexer_type == "rag": - # Deploy the inquiry index + Args: + arguments (argparse.Namespace): The arguments passed to the script""" + if arguments.indexer_type == "rag": index_config = RagDocumentsAISearch( - endpoint=endpoint, - credential=credential, - suffix=args.suffix, - rebuild=args.rebuild, - enable_page_by_chunking=args.enable_page_chunking, + suffix=arguments.suffix, + rebuild=arguments.rebuild, + enable_page_by_chunking=arguments.enable_page_chunking, ) else: raise ValueError("Invalid Indexer Type") index_config.deploy() - if args.rebuild: + if arguments.rebuild: index_config.reset_indexer() @@ -75,4 +52,4 @@ def main(args): ) args = parser.parse_args() - main(args) + deploy_config(args) diff --git a/ai_search_with_adi/ai_search/rag_documents.py b/ai_search_with_adi/ai_search/rag_documents.py index 8adfe16..8541478 100644 --- a/ai_search_with_adi/ai_search/rag_documents.py +++ b/ai_search_with_adi/ai_search/rag_documents.py @@ -24,7 +24,6 @@ ) from ai_search import AISearch from ai_search_with_adi.ai_search.environment import ( - get_search_embedding_model_dimensions, IndexerType, ) @@ -34,13 +33,17 @@ class RagDocumentsAISearch(AISearch): def __init__( self, - endpoint, - credential, - suffix=None, - rebuild=False, + suffix: str | None = None, + rebuild: bool | None = False, enable_page_by_chunking=False, ): - super().__init__(endpoint, credential, suffix, rebuild) + """Initialize the RagDocumentsAISearch class. This class implements the deployment of the rag document index. + + Args: + suffix (str, optional): The suffix for the indexer. Defaults to None. If an suffix is provided, it is assumed to be a test indexer. + rebuild (bool, optional): Whether to rebuild the index. Defaults to False. + """ + super().__init__(suffix, rebuild) self.indexer_type = IndexerType.RAG_DOCUMENTS if enable_page_by_chunking is not None: @@ -80,9 +83,7 @@ def get_index_fields(self) -> list[SearchableField]: SearchField( name="ChunkEmbedding", type=SearchFieldDataType.Collection(SearchFieldDataType.Single), - vector_search_dimensions=get_search_embedding_model_dimensions( - self.indexer_type - ), + vector_search_dimensions=self.environment.embedding_model_dimensions, vector_search_profile_name=self.vector_search_profile_name, ), SearchableField( @@ -224,6 +225,8 @@ def get_indexer(self) -> SearchIndexer: Returns: SearchIndexer: The indexer for inquiry document""" + + # Only place on schedule if it is not a test deployment if self.test: schedule = None batch_size = 4 @@ -231,12 +234,17 @@ def get_indexer(self) -> SearchIndexer: schedule = {"interval": "PT15M"} batch_size = 16 + if self.environment.use_private_endpoint: + execution_environment = IndexerExecutionEnvironment.PRIVATE + else: + execution_environment = IndexerExecutionEnvironment.STANDARD + indexer_parameters = IndexingParameters( batch_size=batch_size, configuration=IndexingParametersConfiguration( data_to_extract=BlobIndexerDataToExtract.ALL_METADATA, query_timeout=None, - execution_environment=IndexerExecutionEnvironment.PRIVATE, + execution_environment=execution_environment, fail_on_unprocessable_document=False, fail_on_unsupported_content_type=False, index_storage_metadata_only_for_oversized_documents=True,