From ad8684f9184add86345a492412148f3aa3de900d Mon Sep 17 00:00:00 2001 From: Ben Constable Date: Tue, 10 Sep 2024 11:07:27 +0100 Subject: [PATCH] Update some of the deployment scripts --- ai_search_with_adi/ai_search/ai_search.py | 315 ++++-------------- ai_search_with_adi/ai_search/deploy.py | 7 +- ai_search_with_adi/ai_search/environment.py | 17 +- .../{inquiry_document.py => rag_documents.py} | 122 ++----- .../function_apps/indexer/adi_2_ai_search.py | 92 ++--- 5 files changed, 157 insertions(+), 396 deletions(-) rename ai_search_with_adi/ai_search/{inquiry_document.py => rag_documents.py} (63%) diff --git a/ai_search_with_adi/ai_search/ai_search.py b/ai_search_with_adi/ai_search/ai_search.py index 6ababd7..d2c9ba5 100644 --- a/ai_search_with_adi/ai_search/ai_search.py +++ b/ai_search_with_adi/ai_search/ai_search.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - +import logging from abc import ABC, abstractmethod from azure.search.documents.indexes.models import ( SearchIndex, @@ -12,8 +12,9 @@ NativeBlobSoftDeleteDeletionDetectionPolicy, HighWaterMarkChangeDetectionPolicy, WebApiSkill, - CustomVectorizer, - CustomWebApiParameters, + AzureOpenAIEmbeddingSkill, + AzureOpenAIVectorizer, + AzureOpenAIParameters, SearchIndexer, SearchIndexerSkillset, SearchIndexerDataContainer, @@ -23,11 +24,8 @@ OutputFieldMappingEntry, InputFieldMappingEntry, SynonymMap, - DocumentExtractionSkill, - OcrSkill, - MergeSkill, - ConditionalSkill, SplitSkill, + SearchIndexerIndexProjections, ) from azure.core.exceptions import HttpResponseError from azure.search.documents.indexes import SearchIndexerClient, SearchIndexClient @@ -37,7 +35,6 @@ get_custom_skill_function_url, get_managed_identity_fqname, get_function_app_authresourceid, - IndexerType, ) @@ -53,7 +50,10 @@ def __init__( Args: endpoint (str): The search endpoint - credential (AzureKeyCredential): The search credential""" + credential (AzureKeyCredential): The search credential + suffix (str, optional): The suffix for the indexer. Defaults to None. + rebuild (bool, optional): Whether to rebuild the index. Defaults to False. + """ self.indexer_type = None if rebuild is not None: @@ -100,20 +100,18 @@ def data_source_name(self): @property def vector_search_profile_name(self): """Get the vector search profile name for the indexer.""" - return ( - f"{str(self.indexer_type.value)}-compass-vector-search-profile{self.suffix}" - ) + return f"{str(self.indexer_type.value)}-vector-search-profile{self.suffix}" @property def vectorizer_name(self): """Get the vectorizer name.""" - return f"{str(self.indexer_type.value)}-compass-vectorizer{self.suffix}" + return f"{str(self.indexer_type.value)}-vectorizer{self.suffix}" @property def algorithm_name(self): - """Gtt the algorithm name""" + """Get the algorithm name""" - return f"{str(self.indexer_type.value)}-hnsw-algorithm{self.suffix}" + return f"{str(self.indexer_type.value)}-algorithm{self.suffix}" @abstractmethod def get_index_fields(self) -> list[SearchableField]: @@ -130,18 +128,21 @@ def get_semantic_search(self) -> SemanticSearch: SemanticSearch: The semantic search configuration""" @abstractmethod - def get_skills(self): - """Get the skillset for the indexer.""" + def get_skills(self) -> list: + """Get the skillset for the indexer. 
+ + Returns: + list: The skillsets used in the indexer""" @abstractmethod def get_indexer(self) -> SearchIndexer: """Get the indexer for the indexer.""" - def get_index_projections(self): + @abstractmethod + def get_index_projections(self) -> SearchIndexerIndexProjections: """Get the index projections for the indexer.""" - return None - def get_synonym_map_names(self): + def get_synonym_map_names(self) -> list[str]: """Get the synonym map names for the indexer.""" return [] @@ -158,12 +159,7 @@ def get_user_assigned_managed_identity( def get_data_source(self) -> SearchIndexerDataSourceConnection: """Get the data source for the indexer.""" - if self.indexer_type == IndexerType.BUSINESS_GLOSSARY: - data_deletion_detection_policy = None - else: - data_deletion_detection_policy = ( - NativeBlobSoftDeleteDeletionDetectionPolicy() - ) + data_deletion_detection_policy = NativeBlobSoftDeleteDeletionDetectionPolicy() data_change_detection_policy = HighWaterMarkChangeDetectionPolicy( high_water_mark_column_name="metadata_storage_last_modified" @@ -185,52 +181,6 @@ def get_data_source(self) -> SearchIndexerDataSourceConnection: return data_source_connection - def get_compass_vector_custom_skill( - self, context, source, target_name="vector" - ) -> WebApiSkill: - """Get the custom skill for compass. - - Args: - ----- - context (str): The context of the skill - source (str): The source of the skill - target_name (str): The target name of the skill - - Returns: - -------- - WebApiSkill: The custom skill for compass""" - - if self.test: - batch_size = 2 - degree_of_parallelism = 2 - else: - batch_size = 4 - degree_of_parallelism = 8 - - embedding_skill_inputs = [ - InputFieldMappingEntry(name="text", source=source), - ] - embedding_skill_outputs = [ - OutputFieldMappingEntry(name="vector", target_name=target_name) - ] - # Limit the number of documents to be processed in parallel to avoid timing out on compass api - embedding_skill = WebApiSkill( - name="Compass Connector API", - description="Skill to generate embeddings via compass API connector", - context=context, - uri=get_custom_skill_function_url("compass"), - timeout="PT230S", - batch_size=batch_size, - degree_of_parallelism=degree_of_parallelism, - http_method="POST", - inputs=embedding_skill_inputs, - outputs=embedding_skill_outputs, - auth_resource_id=get_function_app_authresourceid(), - auth_identity=self.get_user_assigned_managed_identity(), - ) - - return embedding_skill - def get_pre_embedding_cleaner_skill( self, context, source, chunk_by_page=False, target_name="cleaned_chunk" ) -> WebApiSkill: @@ -260,13 +210,15 @@ def get_pre_embedding_cleaner_skill( pre_embedding_cleaner_skill_outputs = [ OutputFieldMappingEntry(name="cleaned_chunk", target_name=target_name), OutputFieldMappingEntry(name="chunk", target_name="chunk"), - OutputFieldMappingEntry(name="section", target_name="eachsection"), + OutputFieldMappingEntry(name="section", target_name="section"), ] if chunk_by_page: pre_embedding_cleaner_skill_outputs.extend( [ - OutputFieldMappingEntry(name="page_number", target_name="page_no"), + OutputFieldMappingEntry( + name="page_number", target_name="page_number" + ), ] ) @@ -313,7 +265,6 @@ def get_text_split_skill(self, context, source) -> SplitSkill: return text_split_skill - def get_adi_skill(self, chunk_by_page=False) -> WebApiSkill: """Get the custom skill for adi. 
@@ -361,45 +312,32 @@ def get_adi_skill(self, chunk_by_page=False) -> WebApiSkill: return adi_skill - def get_excel_skill(self) -> WebApiSkill: - """Get the custom skill for adi. + def get_vector_skill( + self, context, source, target_name="vector" + ) -> AzureOpenAIEmbeddingSkill: + """Get the vector skill for the indexer. Returns: - -------- - WebApiSkill: The custom skill for adi""" + AzureOpenAIEmbeddingSkill: The vector skill for the indexer""" - if self.test: - batch_size = 1 - degree_of_parallelism = 4 - else: - batch_size = 1 - degree_of_parallelism = 8 - - output = [ - OutputFieldMappingEntry(name="extracted_content", target_name="pages") + embedding_skill_inputs = [ + InputFieldMappingEntry(name="text", source=source), + ] + embedding_skill_outputs = [ + OutputFieldMappingEntry(name="vector", target_name=target_name) ] - xlsx_skill = WebApiSkill( - name="XLSX Skill", - description="Skill to generate Markdown from XLSX", - context="/document", - uri=get_custom_skill_function_url("xlsx"), - timeout="PT230S", - batch_size=batch_size, - degree_of_parallelism=degree_of_parallelism, - http_method="POST", - http_headers={}, - inputs=[ - InputFieldMappingEntry( - name="source", source="/document/metadata_storage_path" - ) - ], - outputs=output, - auth_resource_id=get_function_app_authresourceid(), - auth_identity=self.get_user_assigned_managed_identity(), + vector_skill = AzureOpenAIEmbeddingSkill( + name="Vector Skill", + description="Skill to generate embeddings", + context=context, + deployment_id="0", + model_name="text-embedding-3-large", + inputs=embedding_skill_inputs, + outputs=embedding_skill_outputs, ) - return xlsx_skill + return vector_skill def get_key_phrase_extraction_skill(self, context, source) -> WebApiSkill: """Get the key phrase extraction skill. @@ -443,126 +381,7 @@ def get_key_phrase_extraction_skill(self, context, source) -> WebApiSkill: return key_phrase_extraction_skill - def get_document_extraction_skill(self, context, source) -> DocumentExtractionSkill: - """Get the document extraction utility skill. 
- - Args: - ----- - context (str): The context of the skill - source (str): The source of the skill - - Returns: - -------- - DocumentExtractionSkill: The document extraction utility skill""" - - doc_extraction_skill = DocumentExtractionSkill( - description="Extraction skill to extract content from office docs like excel, ppt, doc etc", - context=context, - inputs=[InputFieldMappingEntry(name="file_data", source=source)], - outputs=[ - OutputFieldMappingEntry( - name="content", target_name="extracted_content" - ), - OutputFieldMappingEntry( - name="normalized_images", target_name="extracted_normalized_images" - ), - ], - ) - - return doc_extraction_skill - - def get_ocr_skill(self, context, source) -> OcrSkill: - """Get the ocr utility skill - Args: - ----- - context (str): The context of the skill - source (str): The source of the skill - - Returns: - -------- - OcrSkill: The ocr skill""" - - if self.test: - batch_size = 2 - degree_of_parallelism = 2 - else: - batch_size = 2 - degree_of_parallelism = 2 - - ocr_skill_inputs = [ - InputFieldMappingEntry(name="image", source=source), - ] - ocr__skill_outputs = [OutputFieldMappingEntry(name="text", target_name="text")] - ocr_skill = WebApiSkill( - name="ocr API", - description="Skill to extract text from images", - context=context, - uri=get_custom_skill_function_url("ocr"), - timeout="PT230S", - batch_size=batch_size, - degree_of_parallelism=degree_of_parallelism, - http_method="POST", - inputs=ocr_skill_inputs, - outputs=ocr__skill_outputs, - auth_resource_id=get_function_app_authresourceid(), - auth_identity=self.get_user_assigned_managed_identity(), - ) - - return ocr_skill - - def get_merge_skill(self, context, source) -> MergeSkill: - """Get the merge - Args: - ----- - context (str): The context of the skill - source (array): The source of the skill - - Returns: - -------- - mergeSkill: The merge skill""" - - merge_skill = MergeSkill( - description="Merge skill for combining OCR'd and regular text", - context=context, - inputs=[ - InputFieldMappingEntry(name="text", source=source[0]), - InputFieldMappingEntry(name="itemsToInsert", source=source[1]), - InputFieldMappingEntry(name="offsets", source=source[2]), - ], - outputs=[ - OutputFieldMappingEntry(name="mergedText", target_name="merged_content") - ], - ) - - return merge_skill - - def get_conditional_skill(self, context, source) -> ConditionalSkill: - """Get the merge - Args: - ----- - context (str): The context of the skill - source (array): The source of the skill - - Returns: - -------- - ConditionalSkill: The conditional skill""" - - conditional_skill = ConditionalSkill( - description="Select between OCR and Document Extraction output", - context=context, - inputs=[ - InputFieldMappingEntry(name="condition", source=source[0]), - InputFieldMappingEntry(name="whenTrue", source=source[1]), - InputFieldMappingEntry(name="whenFalse", source=source[2]), - ], - outputs=[ - OutputFieldMappingEntry(name="output", target_name="updated_content") - ], - ) - - return conditional_skill - - def get_compass_vector_search(self) -> VectorSearch: + def get_vector_search(self) -> VectorSearch: """Get the vector search configuration for compass. 
Args: @@ -584,13 +403,9 @@ def get_compass_vector_search(self) -> VectorSearch: ) ], vectorizers=[ - CustomVectorizer( + AzureOpenAIVectorizer( name=self.vectorizer_name, - custom_web_api_parameters=CustomWebApiParameters( - uri=get_custom_skill_function_url("compass"), - auth_resource_id=get_function_app_authresourceid(), - auth_identity=self.get_user_assigned_managed_identity(), - ), + azure_open_ai_parameters=AzureOpenAIParameters(), ), ], ) @@ -601,7 +416,7 @@ def deploy_index(self): """This function deploys index""" index_fields = self.get_index_fields() - vector_search = self.get_compass_vector_search() + vector_search = self.get_vector_search() semantic_search = self.get_semantic_search() index = SearchIndex( name=self.index_name, @@ -613,7 +428,7 @@ def deploy_index(self): self._search_index_client.delete_index(self.index_name) self._search_index_client.create_or_update_index(index) - print(f"{index.name} created") + logging.info("%s index created", index.name) def deploy_skillset(self): """This function deploys the skillset.""" @@ -628,7 +443,8 @@ def deploy_skillset(self): ) self._search_indexer_client.create_or_update_skillset(skillset) - print(f"{skillset.name} created") + + logging.info("%s skillset created", skillset.name) def deploy_data_source(self): """This function deploys the data source.""" @@ -638,9 +454,7 @@ def deploy_data_source(self): data_source ) - print(f"Data source '{result.name}' created or updated") - - return result + logging.info("%s data source created", result.name) def deploy_indexer(self): """This function deploys the indexer.""" @@ -648,33 +462,34 @@ def deploy_indexer(self): result = self._search_indexer_client.create_or_update_indexer(indexer) - print(f"Indexer '{result.name}' created or updated") - - return result + logging.info("%s indexer created", result.name) def run_indexer(self): """This function runs the indexer.""" self._search_indexer_client.run_indexer(self.indexer_name) - print( - f"{self.indexer_name} is running. If queries return no results, please wait a bit and try again." + logging.info( + "%s is running. If queries return no results, please wait a bit and try again.", + self.indexer_name, ) def reset_indexer(self): """This function runs the indexer.""" self._search_indexer_client.reset_indexer(self.indexer_name) - print(f"{self.indexer_name} reset.") + logging.info("%s reset.", self.indexer_name) + + def deploy_synonym_map(self): + """This function deploys the synonym map.""" - def deploy_synonym_map(self) -> list[SearchableField]: synonym_maps = self.get_synonym_map_names() if len(synonym_maps) > 0: for synonym_map in synonym_maps: try: synonym_map = SynonymMap(name=synonym_map, synonyms="") - self._search_index_client.create_synonym_map(synonym_map) - except HttpResponseError: - print("Unable to deploy synonym map as it already exists.") + self._search_index_client.create_or_update_synonym_map(synonym_map) + except HttpResponseError as e: + logging.error("Unable to deploy synonym map. 
%s", e) def deploy(self): """This function deploys the whole AI search pipeline.""" @@ -684,4 +499,4 @@ def deploy(self): self.deploy_skillset() self.deploy_indexer() - print(f"{str(self.indexer_type.value)} deployed") + logging.info("%s setup deployed", self.indexer_type.value) diff --git a/ai_search_with_adi/ai_search/deploy.py b/ai_search_with_adi/ai_search/deploy.py index d533340..5e1ffb2 100644 --- a/ai_search_with_adi/ai_search/deploy.py +++ b/ai_search_with_adi/ai_search/deploy.py @@ -11,7 +11,8 @@ from azure.core.credentials import AzureKeyCredential from azure.identity import DefaultAzureCredential from azure.keyvault.secrets import SecretClient -from inquiry_document import InquiryDocumentAISearch +from ai_search_with_adi.ai_search.rag_documents import RagDocumentsAISearch + def main(args): endpoint = get_search_endpoint() @@ -28,9 +29,9 @@ def main(args): credential = AzureKeyCredential(get_search_key(client=client)) print("Using Azure Key credential") - if args.indexer_type == "inquiry": + if args.indexer_type == "rag": # Deploy the inquiry index - index_config = InquiryDocumentAISearch( + index_config = RagDocumentsAISearch( endpoint=endpoint, credential=credential, suffix=args.suffix, diff --git a/ai_search_with_adi/ai_search/environment.py b/ai_search_with_adi/ai_search/environment.py index a17d3a1..b806fe6 100644 --- a/ai_search_with_adi/ai_search/environment.py +++ b/ai_search_with_adi/ai_search/environment.py @@ -12,17 +12,17 @@ class IndexerType(Enum): """The type of the indexer""" - INQUIRY_DOCUMENT = "inquiry-document" - SUMMARY_DOCUMENT = "summary-document" - BUSINESS_GLOSSARY = "business-glossary" + RAG_DOCUMENTS = "rag-documents" + # key vault -def get_key_vault_url() ->str: +def get_key_vault_url() -> str: """ This function returns key vault url """ return os.environ.get("KeyVault__Url") + # managed identity id def get_managed_identity_id() -> str: """ @@ -52,12 +52,14 @@ def get_function_app_end_point() -> str: """ return os.environ.get("FunctionApp__Endpoint") + def get_function_app_key() -> str: """ This function returns function app key """ return os.environ.get("FunctionApp__Key") + def get_function_app_compass_function() -> str: """ This function returns function app compass function name @@ -119,10 +121,13 @@ def get_search_key(client) -> str: """ This function returns azure ai search service admin key """ - search_service_key_secret_name = str(os.environ.get("AIService__AzureSearchOptions__name")) + "-PrimaryKey" + search_service_key_secret_name = ( + str(os.environ.get("AIService__AzureSearchOptions__name")) + "-PrimaryKey" + ) retrieved_secret = client.get_secret(search_service_key_secret_name) return retrieved_secret.value + def get_search_key_secret() -> str: """ This function returns azure ai search service admin key @@ -143,12 +148,14 @@ def get_search_embedding_model_dimensions(indexer_type: IndexerType) -> str: f"AIService__AzureSearchOptions__{normalised_indexer_type}__EmbeddingDimensions" ) + def get_blob_connection_string() -> str: """ This function returns azure blob storage connection string """ return os.environ.get("StorageAccount__ConnectionString") + def get_fq_blob_connection_string() -> str: """ This function returns azure blob storage connection string diff --git a/ai_search_with_adi/ai_search/inquiry_document.py b/ai_search_with_adi/ai_search/rag_documents.py similarity index 63% rename from ai_search_with_adi/ai_search/inquiry_document.py rename to ai_search_with_adi/ai_search/rag_documents.py index 36a55a4..8adfe16 100644 --- 
a/ai_search_with_adi/ai_search/inquiry_document.py +++ b/ai_search_with_adi/ai_search/rag_documents.py @@ -29,8 +29,8 @@ ) -class InquiryDocumentAISearch(AISearch): - """This class is used to deploy the inquiry document index.""" +class RagDocumentsAISearch(AISearch): + """This class is used to deploy the rag document index.""" def __init__( self, @@ -42,40 +42,12 @@ def __init__( ): super().__init__(endpoint, credential, suffix, rebuild) - self.indexer_type = IndexerType.INQUIRY_DOCUMENT + self.indexer_type = IndexerType.RAG_DOCUMENTS if enable_page_by_chunking is not None: self.enable_page_by_chunking = enable_page_by_chunking else: self.enable_page_by_chunking = False - @property - def index_name(self): - """Get the index name for the indexer. Overwritten as this class is subclassed by InquiryDocumentXLSX and they should both point to the same index""" - return f"{str(IndexerType.INQUIRY_DOCUMENT.value)}-index{self.suffix}" - - @property - def vector_search_profile_name(self): - """Get the vector search profile name for the indexer. Overwritten as this class is subclassed by InquiryDocumentXLSX and they should both point to the same index""" - return f"{str(IndexerType.INQUIRY_DOCUMENT.value)}-compass-vector-search-profile{self.suffix}" - - @property - def vectorizer_name(self): - """Get the vectorizer name. Overwritten as this class is subclassed by InquiryDocumentXLSX and they should both point to the same index""" - return ( - f"{str(IndexerType.INQUIRY_DOCUMENT.value)}-compass-vectorizer{self.suffix}" - ) - - @property - def algorithm_name(self): - """Gtt the algorithm name. Overwritten as this class is subclassed by InquiryDocumentXLSX and they should both point to the same index""" - - return f"{str(IndexerType.INQUIRY_DOCUMENT.value)}-hnsw-algorithm{self.suffix}" - - @property - def semantic_config_name(self): - """Get the semantic config name for the indexer. Overwritten as this class is subclassed by InquiryDocumentXLSX and they should both point to the same index""" - return f"{str(IndexerType.INQUIRY_DOCUMENT.value)}-semantic-config{self.suffix}" - def get_index_fields(self) -> list[SearchableField]: """This function returns the index fields for inquiry document. 
@@ -85,42 +57,28 @@ def get_index_fields(self) -> list[SearchableField]: fields = [ SimpleField(name="Id", type=SearchFieldDataType.String, filterable=True), SearchableField( - name="Field1", type=SearchFieldDataType.String, filterable=True - ), - SearchableField( - name="Field2", - type=SearchFieldDataType.String, - sortable=True, - filterable=True, - facetable=True, - ), - SearchableField( - name="Field3", - type=SearchFieldDataType.String, - sortable=True, - filterable=True, - facetable=True, + name="Title", type=SearchFieldDataType.String, filterable=True ), SearchableField( - name="Field4", + name="ChunkId", type=SearchFieldDataType.String, key=True, - analyzer_name="a1", + analyzer_name="keyword", ), SearchableField( - name="Field5", + name="Chunk", type=SearchFieldDataType.String, sortable=False, filterable=False, facetable=False, ), SearchableField( - name="Field6", + name="Sections", type=SearchFieldDataType.String, collection=True, ), SearchField( - name="EmbeddingField", + name="ChunkEmbedding", type=SearchFieldDataType.Collection(SearchFieldDataType.Single), vector_search_dimensions=get_search_embedding_model_dimensions( self.indexer_type @@ -128,17 +86,10 @@ def get_index_fields(self) -> list[SearchableField]: vector_search_profile_name=self.vector_search_profile_name, ), SearchableField( - name="Field7", type=SearchFieldDataType.String, collection=True + name="Keywords", type=SearchFieldDataType.String, collection=True ), SearchableField( - name="Field8", - type=SearchFieldDataType.String, - sortable=True, - filterable=True, - facetable=True, - ), - SearchableField( - name="Field9", + name="SourceUri", type=SearchFieldDataType.String, sortable=True, filterable=True, @@ -150,7 +101,7 @@ def get_index_fields(self) -> list[SearchableField]: fields.extend( [ SearchableField( - name="Field10", + name="PageNumber", type=SearchFieldDataType.Int64, sortable=True, filterable=True, @@ -170,11 +121,11 @@ def get_semantic_search(self) -> SemanticSearch: semantic_config = SemanticConfiguration( name=self.semantic_config_name, prioritized_fields=SemanticPrioritizedFields( - title_field=SemanticField(field_name="Field1"), - content_fields=[SemanticField(field_name="Field2")], + title_field=SemanticField(field_name="Title"), + content_fields=[SemanticField(field_name="Chunk")], keywords_fields=[ - SemanticField(field_name="Field3"), - SemanticField(field_name="Field4"), + SemanticField(field_name="Keywords"), + SemanticField(field_name="Sections"), ], ), ) @@ -183,27 +134,27 @@ def get_semantic_search(self) -> SemanticSearch: return semantic_search - def get_skills(self): - """This function returns the skills for inquiry document""" + def get_skills(self) -> list: + """Get the skillset for the indexer. 
- adi_skill = self.get_adi_skill(self.enable_page_by_chunking) + Returns: + list: The skillsets used in the indexer""" + adi_skill = self.get_adi_skill(self.enable_page_by_chunking) text_split_skill = self.get_text_split_skill( "/document", "/document/extracted_content/content" ) - pre_embedding_cleaner_skill = self.get_pre_embedding_cleaner_skill( "/document/pages/*", "/document/pages/*", self.enable_page_by_chunking ) - key_phrase_extraction_skill = self.get_key_phrase_extraction_skill( "/document/pages/*", "/document/pages/*/cleaned_chunk" ) - embedding_skill = self.get_compass_vector_custom_skill( + embedding_skill = self.get_vector_skill( "/document/pages/*", "/document/pages/*/cleaned_chunk" ) @@ -233,19 +184,13 @@ def get_index_projections(self) -> SearchIndexerIndexProjections: name="ChunkEmbedding", source="/document/pages/*/vector", ), - InputFieldMappingEntry(name="Field1", source="/document/Field1"), - InputFieldMappingEntry(name="Field2", source="/document/Field2"), - InputFieldMappingEntry(name="Field3", source="/document/Field3"), - InputFieldMappingEntry(name="Field4", source="/document/Field4"), - InputFieldMappingEntry( - name="Field5", source="/document/pages/*/Field5" - ), + InputFieldMappingEntry(name="Title", source="/document/Title"), + InputFieldMappingEntry(name="SourceUri", source="/document/SourceUri"), InputFieldMappingEntry( - name="Field6", - source="/document/Field6", + name="Keywords", source="/document/pages/*/keywords" ), InputFieldMappingEntry( - name="Field7", source="/document/pages/*/Field7" + name="Sections", source="/document/pages/*/sections" ), ] @@ -253,7 +198,7 @@ def get_index_projections(self) -> SearchIndexerIndexProjections: mappings.extend( [ InputFieldMappingEntry( - name="Field8", source="/document/pages/*/Field8" + name="PageNumber", source="/document/pages/*/page_number" ) ] ) @@ -295,7 +240,7 @@ def get_indexer(self) -> SearchIndexer: fail_on_unprocessable_document=False, fail_on_unsupported_content_type=False, index_storage_metadata_only_for_oversized_documents=True, - indexed_file_name_extensions=".pdf,.pptx,.docx", + indexed_file_name_extensions=".pdf,.pptx,.docx,.xlsx,.txt", ), max_failed_items=5, ) @@ -311,16 +256,9 @@ def get_indexer(self) -> SearchIndexer: FieldMapping( source_field_name="metadata_storage_name", target_field_name="Title" ), - FieldMapping(source_field_name="Field1", target_field_name="Field1"), - FieldMapping( - source_field_name="Field2", target_field_name="Field2" - ), - FieldMapping( - source_field_name="Field3", target_field_name="Field3" - ), FieldMapping( - source_field_name="Field4", - target_field_name="Field4", + source_field_name="metadata_storage_path", + target_field_name="SourceUri", ), ], parameters=indexer_parameters, diff --git a/ai_search_with_adi/function_apps/indexer/adi_2_ai_search.py b/ai_search_with_adi/function_apps/indexer/adi_2_ai_search.py index a477a85..ae0474a 100644 --- a/ai_search_with_adi/function_apps/indexer/adi_2_ai_search.py +++ b/ai_search_with_adi/function_apps/indexer/adi_2_ai_search.py @@ -11,13 +11,13 @@ import fitz from PIL import Image import io -import aiohttp import logging from common.storage_account import StorageAccountHelper import concurrent.futures import json from openai import AzureOpenAI + def crop_image_from_pdf_page(pdf_path, page_number, bounding_box): """ Crops a region from a given page in a PDF and returns it as an image. 
@@ -41,7 +41,9 @@ def crop_image_from_pdf_page(pdf_path, page_number, bounding_box): return img -def clean_adi_markdown(markdown_text: str, page_no:int,remove_irrelevant_figures=False): +def clean_adi_markdown( + markdown_text: str, page_no: int, remove_irrelevant_figures=False +): """Clean Markdown text extracted by the Azure Document Intelligence service. Args: @@ -73,11 +75,10 @@ def clean_adi_markdown(markdown_text: str, page_no:int,remove_irrelevant_figures comment_patterns = r"||" cleaned_text = re.sub(comment_patterns, "", markdown_text, flags=re.DOTALL) - combined_pattern = r'(.*?)\n===|\n## ?(.*?)\n|\n### ?(.*?)\n' + combined_pattern = r"(.*?)\n===|\n## ?(.*?)\n|\n### ?(.*?)\n" doc_metadata = re.findall(combined_pattern, cleaned_text, re.DOTALL) doc_metadata = [match for group in doc_metadata for match in group if match] - if remove_irrelevant_figures: # Remove irrelevant figures irrelevant_figure_pattern = ( @@ -89,12 +90,12 @@ def clean_adi_markdown(markdown_text: str, page_no:int,remove_irrelevant_figures # Replace ':selected:' with a new line cleaned_text = re.sub(r":(selected|unselected):", "\n", cleaned_text) - output_dict['content'] = cleaned_text - output_dict['section'] = doc_metadata + output_dict["content"] = cleaned_text + output_dict["sections"] = doc_metadata # add page number when chunk by page is enabled - if page_no> -1: - output_dict['page_number'] = page_no + if page_no > -1: + output_dict["page_number"] = page_no return output_dict @@ -152,60 +153,59 @@ async def understand_image_with_gptv(image_base64, caption): deployment_name = os.environ["AzureAI__GPT4V_Deployment"] api_base = os.environ["AzureAI__GPT4V_APIbase"] - client = AzureOpenAI( - api_key=api_key, + api_key=api_key, api_version=api_version, - base_url=f"{api_base}/openai/deployments/{deployment_name}" + base_url=f"{api_base}/openai/deployments/{deployment_name}", ) # We send both image caption and the image body to GPTv for better understanding if caption != "": response = client.chat.completions.create( - model=deployment_name, - messages=[ - { "role": "system", "content": "You are a helpful assistant." }, - { "role": "user", "content": [ - { - "type": "text", - "text": f"Describe this image (note: it has image caption: {caption}):" + model=deployment_name, + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": [ + { + "type": "text", + "text": f"Describe this image (note: it has image caption: {caption}):", }, - { + { "type": "image_base64", - "image_base64": { - "image": image_base64 - } - } - ] } - ], - max_tokens=MAX_TOKENS - ) + "image_base64": {"image": image_base64}, + }, + ], + }, + ], + max_tokens=MAX_TOKENS, + ) else: response = client.chat.completions.create( model=deployment_name, messages=[ - { "role": "system", "content": "You are a helpful assistant." }, - { "role": "user", "content": [ - { - "type": "text", - "text": "Describe this image:" - }, - { - "type": "image_base64", - "image_base64": { - "image": image_base64 - } - } - ] } + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image:"}, + { + "type": "image_base64", + "image_base64": {"image": image_base64}, + }, + ], + }, ], - max_tokens=MAX_TOKENS + max_tokens=MAX_TOKENS, ) img_description = response.choices[0].message.content - + return img_description + def pil_image_to_base64(image, image_format="JPEG"): """ Converts a PIL image to a base64-encoded string. 
@@ -293,10 +293,10 @@ def create_page_wise_content(result: AnalyzeResult) -> list: page.spans[0]["offset"] : page.spans[0]["offset"] + page.spans[0]["length"] ] page_wise_content.append(page_content) - page_number+=1 + page_number += 1 page_numbers.append(page_number) - return page_wise_content,page_numbers + return page_wise_content, page_numbers async def analyse_document(file_path: str) -> AnalyzeResult: @@ -431,7 +431,7 @@ async def process_adi_2_ai_search(record: dict, chunk_by_page: bool = False) -> try: if chunk_by_page: cleaned_result = [] - markdown_content,page_no = create_page_wise_content(result) + markdown_content, page_no = create_page_wise_content(result) tasks = [ process_figures_from_extracted_content( temp_file_path, page_content, result.figures, page_number=idx @@ -455,7 +455,7 @@ async def process_adi_2_ai_search(record: dict, chunk_by_page: bool = False) -> temp_file_path, markdown_content, result.figures ) cleaned_result = clean_adi_markdown( - content_with_figures, page_no=-1,remove_irrelevant_figures=False + content_with_figures, page_no=-1, remove_irrelevant_figures=False ) except Exception as e: logging.error(e)
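
Reviewer note (not part of the patch): a minimal sketch of how the renamed RagDocumentsAISearch class would be deployed end to end, assuming the AIService__*/FunctionApp__*/StorageAccount__* environment variables read by environment.py are set and that get_search_endpoint is importable from that module alongside the other helpers; the suffix, rebuild and chunking values below are placeholders, not values taken from the patch.

from azure.identity import DefaultAzureCredential

from ai_search_with_adi.ai_search.environment import get_search_endpoint
from ai_search_with_adi.ai_search.rag_documents import RagDocumentsAISearch

index_config = RagDocumentsAISearch(
    endpoint=get_search_endpoint(),
    credential=DefaultAzureCredential(),  # deploy.py also supports AzureKeyCredential fetched from Key Vault
    suffix="test",                        # placeholder: appended to the index/indexer/skillset names
    rebuild=True,                         # placeholder: drop and recreate the index before deploying
    enable_page_by_chunking=False,        # placeholder: set True to use the page-wise ADI chunking path
)
index_config.deploy()       # deploys the whole AI search pipeline, ending with the skillset and indexer
index_config.run_indexer()  # queries may return no results until the first indexer run completes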
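
Reviewer note (not part of the patch): get_vector_search() currently builds the AzureOpenAIVectorizer with an empty AzureOpenAIParameters(), and get_vector_skill() hard-codes deployment_id="0". Both normally need to point at a real Azure OpenAI embedding deployment before indexing or query-time vectorization will work. Below is a hedged sketch of what the vectorizer wiring could look like; the endpoint and deployment name are placeholders, and the keyword names follow the azure-search-documents 11.4/11.5 models already used elsewhere in this patch (newer SDK versions rename some of these fields).

from azure.search.documents.indexes.models import (
    AzureOpenAIParameters,
    AzureOpenAIVectorizer,
)


def build_vectorizer(
    vectorizer_name: str, aoai_endpoint: str, embedding_deployment: str
) -> AzureOpenAIVectorizer:
    """Build a query-time vectorizer wired to a specific Azure OpenAI embedding deployment."""
    return AzureOpenAIVectorizer(
        name=vectorizer_name,
        azure_open_ai_parameters=AzureOpenAIParameters(
            resource_uri=aoai_endpoint,          # placeholder, e.g. https://<resource>.openai.azure.com
            deployment_id=embedding_deployment,  # placeholder, e.g. the text-embedding-3-large deployment name
            # plus api_key=... or auth_identity=... depending on how the search service authenticates
        ),
    )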