diff --git a/ai_search_with_adi/ai_search.py b/ai_search_with_adi/ai_search/ai_search.py similarity index 91% rename from ai_search_with_adi/ai_search.py rename to ai_search_with_adi/ai_search/ai_search.py index 7573055..6ababd7 100644 --- a/ai_search_with_adi/ai_search.py +++ b/ai_search_with_adi/ai_search/ai_search.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from abc import ABC, abstractmethod from azure.search.documents.indexes.models import ( SearchIndex, @@ -28,7 +31,7 @@ ) from azure.core.exceptions import HttpResponseError from azure.search.documents.indexes import SearchIndexerClient, SearchIndexClient -from environment import ( +from ai_search_with_adi.ai_search.environment import ( get_fq_blob_connection_string, get_blob_container_name, get_custom_skill_function_url, @@ -70,31 +73,48 @@ def __init__( @property def indexer_name(self): + """Get the indexer name for the indexer.""" return f"{str(self.indexer_type.value)}-indexer{self.suffix}" @property def skillset_name(self): + """Get the skillset name for the indexer.""" return f"{str(self.indexer_type.value)}-skillset{self.suffix}" @property def semantic_config_name(self): + """Get the semantic config name for the indexer.""" return f"{str(self.indexer_type.value)}-semantic-config{self.suffix}" @property def index_name(self): + """Get the index name for the indexer.""" return f"{str(self.indexer_type.value)}-index{self.suffix}" @property def data_source_name(self): + """Get the data source name for the indexer.""" blob_container_name = get_blob_container_name(self.indexer_type) return f"{blob_container_name}-data-source{self.suffix}" @property def vector_search_profile_name(self): + """Get the vector search profile name for the indexer.""" return ( f"{str(self.indexer_type.value)}-compass-vector-search-profile{self.suffix}" ) + @property + def vectorizer_name(self): + """Get the vectorizer name.""" + return f"{str(self.indexer_type.value)}-compass-vectorizer{self.suffix}" + + @property + def algorithm_name(self): + """Gtt the algorithm name""" + + return f"{str(self.indexer_type.value)}-hnsw-algorithm{self.suffix}" + @abstractmethod def get_index_fields(self) -> list[SearchableField]: """Get the index fields for the indexer. @@ -122,6 +142,7 @@ def get_index_projections(self): return None def get_synonym_map_names(self): + """Get the synonym map names for the indexer.""" return [] def get_user_assigned_managed_identity( @@ -292,67 +313,7 @@ def get_text_split_skill(self, context, source) -> SplitSkill: return text_split_skill - def get_custom_text_split_skill( - self, - context, - source, - text_split_mode="semantic", - maximum_page_length=1000, - separator=" ", - initial_threshold=0.7, - appending_threshold=0.6, - merging_threshold=0.6, - ) -> WebApiSkill: - """Get the custom skill for text split. - - Args: - ----- - context (str): The context of the skill - inputs (List[InputFieldMappingEntry]): The inputs of the skill - outputs (List[OutputFieldMappingEntry]): The outputs of the skill - - Returns: - -------- - WebApiSkill: The custom skill for text split""" - - if self.test: - batch_size = 2 - degree_of_parallelism = 2 - else: - batch_size = 2 - degree_of_parallelism = 6 - - text_split_skill_inputs = [ - InputFieldMappingEntry(name="text", source=source), - ] - - headers = { - "text_split_mode": text_split_mode, - "maximum_page_length": maximum_page_length, - "separator": separator, - "initial_threshold": initial_threshold, - "appending_threshold": appending_threshold, - "merging_threshold": merging_threshold, - } - - text_split_skill = WebApiSkill( - name="Text Split Skill", - description="Skill to split the text before sending to embedding", - context=context, - uri=get_custom_skill_function_url("split"), - timeout="PT230S", - batch_size=batch_size, - degree_of_parallelism=degree_of_parallelism, - http_method="POST", - http_headers=headers, - inputs=text_split_skill_inputs, - outputs=[OutputFieldMappingEntry(name="chunks", target_name="pages")], - auth_resource_id=get_function_app_authresourceid(), - auth_identity=self.get_user_assigned_managed_identity(), - ) - - return text_split_skill - + def get_adi_skill(self, chunk_by_page=False) -> WebApiSkill: """Get the custom skill for adi. @@ -400,6 +361,46 @@ def get_adi_skill(self, chunk_by_page=False) -> WebApiSkill: return adi_skill + def get_excel_skill(self) -> WebApiSkill: + """Get the custom skill for adi. + + Returns: + -------- + WebApiSkill: The custom skill for adi""" + + if self.test: + batch_size = 1 + degree_of_parallelism = 4 + else: + batch_size = 1 + degree_of_parallelism = 8 + + output = [ + OutputFieldMappingEntry(name="extracted_content", target_name="pages") + ] + + xlsx_skill = WebApiSkill( + name="XLSX Skill", + description="Skill to generate Markdown from XLSX", + context="/document", + uri=get_custom_skill_function_url("xlsx"), + timeout="PT230S", + batch_size=batch_size, + degree_of_parallelism=degree_of_parallelism, + http_method="POST", + http_headers={}, + inputs=[ + InputFieldMappingEntry( + name="source", source="/document/metadata_storage_path" + ) + ], + outputs=output, + auth_resource_id=get_function_app_authresourceid(), + auth_identity=self.get_user_assigned_managed_identity(), + ) + + return xlsx_skill + def get_key_phrase_extraction_skill(self, context, source) -> WebApiSkill: """Get the key phrase extraction skill. @@ -570,25 +571,21 @@ def get_compass_vector_search(self) -> VectorSearch: Returns: VectorSearch: The vector search configuration """ - vectorizer_name = ( - f"{str(self.indexer_type.value)}-compass-vectorizer{self.suffix}" - ) - algorithim_name = f"{str(self.indexer_type.value)}-hnsw-algorithm{self.suffix}" vector_search = VectorSearch( algorithms=[ - HnswAlgorithmConfiguration(name=algorithim_name), + HnswAlgorithmConfiguration(name=self.algorithm_name), ], profiles=[ VectorSearchProfile( name=self.vector_search_profile_name, - algorithm_configuration_name=algorithim_name, - vectorizer=vectorizer_name, + algorithm_configuration_name=self.algorithm_name, + vectorizer=self.vectorizer_name, ) ], vectorizers=[ CustomVectorizer( - name=vectorizer_name, + name=self.vectorizer_name, custom_web_api_parameters=CustomWebApiParameters( uri=get_custom_skill_function_url("compass"), auth_resource_id=get_function_app_authresourceid(), diff --git a/ai_search_with_adi/deploy.py b/ai_search_with_adi/ai_search/deploy.py similarity index 67% rename from ai_search_with_adi/deploy.py rename to ai_search_with_adi/ai_search/deploy.py index 1b2190b..d533340 100644 --- a/ai_search_with_adi/deploy.py +++ b/ai_search_with_adi/ai_search/deploy.py @@ -1,35 +1,45 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import argparse -from environment import get_search_endpoint, get_managed_identity_id, get_search_key,get_key_vault_url +from ai_search_with_adi.ai_search.environment import ( + get_search_endpoint, + get_managed_identity_id, + get_search_key, + get_key_vault_url, +) from azure.core.credentials import AzureKeyCredential -from azure.identity import DefaultAzureCredential,ManagedIdentityCredential,EnvironmentCredential +from azure.identity import DefaultAzureCredential from azure.keyvault.secrets import SecretClient from inquiry_document import InquiryDocumentAISearch - def main(args): endpoint = get_search_endpoint() try: - credential = DefaultAzureCredential(managed_identity_client_id =get_managed_identity_id()) + credential = DefaultAzureCredential( + managed_identity_client_id=get_managed_identity_id() + ) # initializing key vault client client = SecretClient(vault_url=get_key_vault_url(), credential=credential) print("Using managed identity credential") except Exception as e: print(e) - credential = ( - AzureKeyCredential(get_search_key(client=client)) - ) + credential = AzureKeyCredential(get_search_key(client=client)) print("Using Azure Key credential") if args.indexer_type == "inquiry": # Deploy the inquiry index index_config = InquiryDocumentAISearch( - endpoint=endpoint, - credential=credential, + endpoint=endpoint, + credential=credential, suffix=args.suffix, - rebuild=args.rebuild, - enable_page_by_chunking=args.enable_page_chunking + rebuild=args.rebuild, + enable_page_by_chunking=args.enable_page_chunking, ) + else: + raise ValueError("Invalid Indexer Type") + index_config.deploy() if args.rebuild: @@ -42,7 +52,7 @@ def main(args): "--indexer_type", type=str, required=True, - help="Type of Indexer want to deploy. inquiry/summary/glossary", + help="Type of Indexer want to deploy.", ) parser.add_argument( "--rebuild", diff --git a/ai_search_with_adi/environment.py b/ai_search_with_adi/ai_search/environment.py similarity index 98% rename from ai_search_with_adi/environment.py rename to ai_search_with_adi/ai_search/environment.py index 7503a68..a17d3a1 100644 --- a/ai_search_with_adi/environment.py +++ b/ai_search_with_adi/ai_search/environment.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """Module providing environment definition""" import os from dotenv import find_dotenv, load_dotenv diff --git a/ai_search_with_adi/inquiry_document.py b/ai_search_with_adi/ai_search/inquiry_document.py similarity index 66% rename from ai_search_with_adi/inquiry_document.py rename to ai_search_with_adi/ai_search/inquiry_document.py index b70251e..36a55a4 100644 --- a/ai_search_with_adi/inquiry_document.py +++ b/ai_search_with_adi/ai_search/inquiry_document.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from azure.search.documents.indexes.models import ( SearchFieldDataType, SearchField, @@ -11,7 +14,6 @@ FieldMapping, IndexingParameters, IndexingParametersConfiguration, - BlobIndexerImageAction, SearchIndexerIndexProjections, SearchIndexerIndexProjectionSelector, SearchIndexerIndexProjectionsParameters, @@ -19,10 +21,9 @@ SimpleField, BlobIndexerDataToExtract, IndexerExecutionEnvironment, - BlobIndexerPDFTextRotationAlgorithm, ) from ai_search import AISearch -from environment import ( +from ai_search_with_adi.ai_search.environment import ( get_search_embedding_model_dimensions, IndexerType, ) @@ -47,9 +48,33 @@ def __init__( else: self.enable_page_by_chunking = False - # explicitly setting it to false no matter what output comes in - # might be removed later - # self.enable_page_by_chunking = False + @property + def index_name(self): + """Get the index name for the indexer. Overwritten as this class is subclassed by InquiryDocumentXLSX and they should both point to the same index""" + return f"{str(IndexerType.INQUIRY_DOCUMENT.value)}-index{self.suffix}" + + @property + def vector_search_profile_name(self): + """Get the vector search profile name for the indexer. Overwritten as this class is subclassed by InquiryDocumentXLSX and they should both point to the same index""" + return f"{str(IndexerType.INQUIRY_DOCUMENT.value)}-compass-vector-search-profile{self.suffix}" + + @property + def vectorizer_name(self): + """Get the vectorizer name. Overwritten as this class is subclassed by InquiryDocumentXLSX and they should both point to the same index""" + return ( + f"{str(IndexerType.INQUIRY_DOCUMENT.value)}-compass-vectorizer{self.suffix}" + ) + + @property + def algorithm_name(self): + """Gtt the algorithm name. Overwritten as this class is subclassed by InquiryDocumentXLSX and they should both point to the same index""" + + return f"{str(IndexerType.INQUIRY_DOCUMENT.value)}-hnsw-algorithm{self.suffix}" + + @property + def semantic_config_name(self): + """Get the semantic config name for the indexer. Overwritten as this class is subclassed by InquiryDocumentXLSX and they should both point to the same index""" + return f"{str(IndexerType.INQUIRY_DOCUMENT.value)}-semantic-config{self.suffix}" def get_index_fields(self) -> list[SearchableField]: """This function returns the index fields for inquiry document. @@ -60,42 +85,42 @@ def get_index_fields(self) -> list[SearchableField]: fields = [ SimpleField(name="Id", type=SearchFieldDataType.String, filterable=True), SearchableField( - name="Title", type=SearchFieldDataType.String, filterable=True + name="Field1", type=SearchFieldDataType.String, filterable=True ), SearchableField( - name="ID1", + name="Field2", type=SearchFieldDataType.String, sortable=True, filterable=True, facetable=True, ), SearchableField( - name="ID2", + name="Field3", type=SearchFieldDataType.String, sortable=True, filterable=True, facetable=True, ), SearchableField( - name="ChunkId", + name="Field4", type=SearchFieldDataType.String, key=True, - analyzer_name="keyword", + analyzer_name="a1", ), SearchableField( - name="Chunk", + name="Field5", type=SearchFieldDataType.String, sortable=False, filterable=False, facetable=False, ), SearchableField( - name="Section", + name="Field6", type=SearchFieldDataType.String, collection=True, ), SearchField( - name="ChunkEmbedding", + name="EmbeddingField", type=SearchFieldDataType.Collection(SearchFieldDataType.Single), vector_search_dimensions=get_search_embedding_model_dimensions( self.indexer_type @@ -103,17 +128,17 @@ def get_index_fields(self) -> list[SearchableField]: vector_search_profile_name=self.vector_search_profile_name, ), SearchableField( - name="Keywords", type=SearchFieldDataType.String, collection=True + name="Field7", type=SearchFieldDataType.String, collection=True ), SearchableField( - name="SourceUrl", + name="Field8", type=SearchFieldDataType.String, sortable=True, filterable=True, facetable=True, ), SearchableField( - name="AdditionalMetadata", + name="Field9", type=SearchFieldDataType.String, sortable=True, filterable=True, @@ -125,7 +150,7 @@ def get_index_fields(self) -> list[SearchableField]: fields.extend( [ SearchableField( - name="PageNumber", + name="Field10", type=SearchFieldDataType.Int64, sortable=True, filterable=True, @@ -145,13 +170,13 @@ def get_semantic_search(self) -> SemanticSearch: semantic_config = SemanticConfiguration( name=self.semantic_config_name, prioritized_fields=SemanticPrioritizedFields( - title_field=SemanticField(field_name="Title"), - content_fields=[SemanticField(field_name="Chunk")], + title_field=SemanticField(field_name="Field1"), + content_fields=[SemanticField(field_name="Field2")], keywords_fields=[ - SemanticField(field_name="Keywords"), - SemanticField(field_name="Section"), - ], - ), + SemanticField(field_name="Field3"), + SemanticField(field_name="Field4"), + ], + ), ) semantic_search = SemanticSearch(configurations=[semantic_config]) @@ -163,14 +188,17 @@ def get_skills(self): adi_skill = self.get_adi_skill(self.enable_page_by_chunking) + text_split_skill = self.get_text_split_skill( "/document", "/document/extracted_content/content" ) + pre_embedding_cleaner_skill = self.get_pre_embedding_cleaner_skill( - "/document/pages/*", "/document/pages/*", self.enable_page_by_chunking + "/document/pages/*", "/document/pages/*", self.enable_page_by_chunking ) + key_phrase_extraction_skill = self.get_key_phrase_extraction_skill( "/document/pages/*", "/document/pages/*/cleaned_chunk" ) @@ -199,60 +227,44 @@ def get_skills(self): def get_index_projections(self) -> SearchIndexerIndexProjections: """This function returns the index projections for inquiry document.""" - mappings =[ - InputFieldMappingEntry( - name="Chunk", source="/document/pages/*/chunk" - ), - InputFieldMappingEntry( - name="ChunkEmbedding", - source="/document/pages/*/vector", - ), - InputFieldMappingEntry( - name="Title", - source="/document/Title" - ), - InputFieldMappingEntry( - name="ID1", - source="/document/ID1" - ), - InputFieldMappingEntry( - name="ID2", - source="/document/ID2" - ), - InputFieldMappingEntry( - name="SourceUrl", - source="/document/SourceUrl" - ), - InputFieldMappingEntry( - name="Keywords", - source="/document/pages/*/keywords" - ), - InputFieldMappingEntry( - name="AdditionalMetadata", - source="/document/AdditionalMetadata", - ), - InputFieldMappingEntry( - name="Section", - source="/document/pages/*/eachsection" - ) - ] - + mappings = [ + InputFieldMappingEntry(name="Chunk", source="/document/pages/*/chunk"), + InputFieldMappingEntry( + name="ChunkEmbedding", + source="/document/pages/*/vector", + ), + InputFieldMappingEntry(name="Field1", source="/document/Field1"), + InputFieldMappingEntry(name="Field2", source="/document/Field2"), + InputFieldMappingEntry(name="Field3", source="/document/Field3"), + InputFieldMappingEntry(name="Field4", source="/document/Field4"), + InputFieldMappingEntry( + name="Field5", source="/document/pages/*/Field5" + ), + InputFieldMappingEntry( + name="Field6", + source="/document/Field6", + ), + InputFieldMappingEntry( + name="Field7", source="/document/pages/*/Field7" + ), + ] + if self.enable_page_by_chunking: mappings.extend( [ InputFieldMappingEntry( - name="PageNumber", source="/document/pages/*/page_no" - ) - ] + name="Field8", source="/document/pages/*/Field8" + ) + ] ) - + index_projections = SearchIndexerIndexProjections( selectors=[ SearchIndexerIndexProjectionSelector( target_index_name=self.index_name, parent_key_field_name="Id", source_context="/document/pages/*", - mappings=mappings + mappings=mappings, ), ], parameters=SearchIndexerIndexProjectionsParameters( @@ -277,12 +289,9 @@ def get_indexer(self) -> SearchIndexer: indexer_parameters = IndexingParameters( batch_size=batch_size, configuration=IndexingParametersConfiguration( - # image_action=BlobIndexerImageAction.GENERATE_NORMALIZED_IMAGE_PER_PAGE, data_to_extract=BlobIndexerDataToExtract.ALL_METADATA, query_timeout=None, - # allow_skillset_to_read_file_data=True, execution_environment=IndexerExecutionEnvironment.PRIVATE, - # pdf_text_rotation_algorithm=BlobIndexerPDFTextRotationAlgorithm.DETECT_ANGLES, fail_on_unprocessable_document=False, fail_on_unsupported_content_type=False, index_storage_metadata_only_for_oversized_documents=True, @@ -302,16 +311,16 @@ def get_indexer(self) -> SearchIndexer: FieldMapping( source_field_name="metadata_storage_name", target_field_name="Title" ), - FieldMapping(source_field_name="ID1", target_field_name="ID1"), + FieldMapping(source_field_name="Field1", target_field_name="Field1"), FieldMapping( - source_field_name="ID2", target_field_name="ID2" + source_field_name="Field2", target_field_name="Field2" ), FieldMapping( - source_field_name="SharePointUrl", target_field_name="SourceUrl" + source_field_name="Field3", target_field_name="Field3" ), FieldMapping( - source_field_name="Additional_Metadata", - target_field_name="AdditionalMetadata", + source_field_name="Field4", + target_field_name="Field4", ), ], parameters=indexer_parameters, diff --git a/ai_search_with_adi/function_apps/common/ai_search.py b/ai_search_with_adi/function_apps/common/ai_search.py index 1bba829..eedf27e 100644 --- a/ai_search_with_adi/function_apps/common/ai_search.py +++ b/ai_search_with_adi/function_apps/common/ai_search.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from azure.search.documents.indexes.aio import SearchIndexerClient, SearchIndexClient from azure.search.documents.aio import SearchClient from azure.search.documents.indexes.models import SynonymMap @@ -75,7 +78,7 @@ async def trigger_indexer(self, indexer_name): logging.error("Unable to run indexer %s", e) async def search_index( - self, index_name, semantic_config, search_text, deal_id=None + self, index_name, semantic_config, search_text, filter_field=None ): """Search the index using the provided search text.""" async with AsyncAzureOpenAI( @@ -98,8 +101,8 @@ async def search_index( fields="ChunkEmbedding", ) - if deal_id: - filter_expression = f"DealId eq '{deal_id}'" + if filter_field: + filter_expression = f"filter_field eq '{filter_field}'" else: filter_expression = None diff --git a/ai_search_with_adi/function_apps/common/delay_processing_exception.py b/ai_search_with_adi/function_apps/common/delay_processing_exception.py new file mode 100644 index 0000000..a8ef226 --- /dev/null +++ b/ai_search_with_adi/function_apps/common/delay_processing_exception.py @@ -0,0 +1,4 @@ +class DelayProcessingException(Exception): + """Exception to delay processing.""" + + pass diff --git a/ai_search_with_adi/function_apps/common/payloads/error.py b/ai_search_with_adi/function_apps/common/payloads/error.py index 49e456e..5a7f443 100644 --- a/ai_search_with_adi/function_apps/common/payloads/error.py +++ b/ai_search_with_adi/function_apps/common/payloads/error.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from typing import Optional from pydantic import BaseModel, Field, ConfigDict from datetime import datetime, timezone diff --git a/ai_search_with_adi/function_apps/common/payloads/header.py b/ai_search_with_adi/function_apps/common/payloads/header.py index e7a521c..d90e684 100644 --- a/ai_search_with_adi/function_apps/common/payloads/header.py +++ b/ai_search_with_adi/function_apps/common/payloads/header.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from pydantic import BaseModel, Field, ConfigDict from datetime import datetime, timezone from enum import Enum @@ -15,8 +18,6 @@ class TaskEnum(Enum): PENDING_INDEX_COMPLETION = "pending_index_completion" PENDING_INDEX_TRIGGER = "pending_index_trigger" - PENDING_SUMMARY_GENERATION = "pending_summary_generation" - class Header(BaseModel): """Header model""" diff --git a/ai_search_with_adi/function_apps/common/payloads/payload.py b/ai_search_with_adi/function_apps/common/payloads/payload.py index fb2f4f9..b36f25f 100644 --- a/ai_search_with_adi/function_apps/common/payloads/payload.py +++ b/ai_search_with_adi/function_apps/common/payloads/payload.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from pydantic import BaseModel, ConfigDict import logging diff --git a/ai_search_with_adi/function_apps/common/payloads/pending_index_completion.py b/ai_search_with_adi/function_apps/common/payloads/pending_index_completion.py index 8aa0335..caf2ade 100644 --- a/ai_search_with_adi/function_apps/common/payloads/pending_index_completion.py +++ b/ai_search_with_adi/function_apps/common/payloads/pending_index_completion.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from pydantic import BaseModel, Field, ConfigDict from datetime import datetime, timezone from typing import Optional, List @@ -11,12 +14,12 @@ class PendingIndexCompletionBody(BaseModel): """Body model""" indexer: str = Field(..., description="The indexer to trigger") - deal_id: Optional[int] = Field(None, description="The deal ID") + id_field: Optional[int] = Field(None, description="The ID field") blob_storage_url: Optional[str] = Field( ..., description="The URL to the blob storage" ) - deal_name: Optional[str] = Field( - None, description="The text name for the integer deal ID" + id_name: Optional[str] = Field( + None, description="The text name for the integer ID field" ) business_unit: Optional[str] = Field(None, description="The business unit") indexer_start_time: Optional[datetime] = Field( diff --git a/ai_search_with_adi/function_apps/common/payloads/pending_index_trigger.py b/ai_search_with_adi/function_apps/common/payloads/pending_index_trigger.py index 2a519d9..e4fd62b 100644 --- a/ai_search_with_adi/function_apps/common/payloads/pending_index_trigger.py +++ b/ai_search_with_adi/function_apps/common/payloads/pending_index_trigger.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from pydantic import BaseModel, Field, ConfigDict from typing import Optional, List @@ -10,12 +13,14 @@ class PendingIndexTriggerBody(BaseModel): """Body model""" indexer: str = Field(..., description="The indexer to trigger") - deal_id: Optional[int] = Field(None, description="The deal ID") + ## this field can be defined based on your id field + id_field: Optional[int] = Field(None, description="The ID field") blob_storage_url: str = Field(..., description="The URL to the blob storage") - deal_name: Optional[str] = Field( - None, description="The text name for the integer deal ID" + ## this field can be defined based on your id field + id_name: Optional[str] = Field( + None, description="The text name for the integer ID field" ) - business_unit: Optional[str] = Field(None, description="The business unit") + additional_field: Optional[str] = Field(None, description="Description of additional_field") __config__ = ConfigDict(extra="ignore") diff --git a/ai_search_with_adi/function_apps/common/requirements.txt b/ai_search_with_adi/function_apps/common/requirements.txt new file mode 100644 index 0000000..daa8b89 --- /dev/null +++ b/ai_search_with_adi/function_apps/common/requirements.txt @@ -0,0 +1,11 @@ +azure-storage-blob +azure-servicebus +azure-core +azure-identity +pydantic +pymongo +azure-search +azure-search-documents==11.6.0b4 +openai +aiohttp +motor diff --git a/ai_search_with_adi/function_apps/common/service_bus.py b/ai_search_with_adi/function_apps/common/service_bus.py new file mode 100644 index 0000000..9e95fe8 --- /dev/null +++ b/ai_search_with_adi/function_apps/common/service_bus.py @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import os +import logging +from datetime import datetime, timezone +from azure.identity.aio import DefaultAzureCredential +from azure.servicebus import ServiceBusMessage +from azure.servicebus.aio import ServiceBusClient + + +class ServiceBusHelper: + def __init__(self): + self._client_id = os.environ["FunctionApp__ClientId"] + + self._endpoint = os.environ["ServiceBusTrigger__fullyQualifiedNamespace"] + + async def get_client(self): + credential = DefaultAzureCredential(managed_identity_client_id=self._client_id) + return ServiceBusClient(self._endpoint, credential) + + async def send_message_to_service_bus_queue( + self, queue, payload, enqueue_time=None, retry=False + ): + # update the header + payload.header.last_processed_timestamp = datetime.now(timezone.utc) + payload.header.task = queue + + if retry: + payload.header.retries_remaining -= 1 + try: + service_bus_client = await self.get_client() + async with service_bus_client: + sender = service_bus_client.get_queue_sender(queue_name=queue.value) + + async with sender: + message = ServiceBusMessage( + body=payload.model_dump_json(), + scheduled_enqueue_time_utc=enqueue_time, + ) + await sender.send_messages(message) + logging.info( + f"Sent a message to the Azure Service Bus queue: {queue}" + ) + except Exception as e: + logging.error(f"Failed to send message to the Azure Service Bus queue: {e}") diff --git a/ai_search_with_adi/function_apps/common/storage_account.py b/ai_search_with_adi/function_apps/common/storage_account.py new file mode 100644 index 0000000..ecb4fea --- /dev/null +++ b/ai_search_with_adi/function_apps/common/storage_account.py @@ -0,0 +1,78 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import logging +import os +import tempfile +from azure.storage.blob.aio import BlobServiceClient +from azure.identity import DefaultAzureCredential +import urllib + + +class StorageAccountHelper: + def __init__(self) -> None: + self._client_id = os.environ["FunctionApp__ClientId"] + + self._endpoint = os.environ["StorageAccount__Endpoint"] + + async def get_client(self): + credential = DefaultAzureCredential(managed_identity_client_id=self._client_id) + + return BlobServiceClient(account_url=self._endpoint, credential=credential) + + async def add_metadata_to_blob(self, source: str, container: str, metadata) -> None: + """Add metadata to the business glossary blob.""" + + blob = urllib.parse.unquote_plus(source) + + blob_service_client = await self.get_client() + async with blob_service_client: + async with blob_service_client.get_blob_client( + container=container, blob=blob + ) as blob_client: + await blob_client.set_blob_metadata(metadata) + + logging.info("Metadata Added") + + async def download_blob_to_temp_dir( + self, source: str, container: str, target_file_name + ) -> tuple[str, dict]: + """Download the business glossary file from the Azure Blob Storage.""" + + blob = urllib.parse.unquote_plus(source) + + blob_service_client = await self.get_client() + async with blob_service_client: + async with blob_service_client.get_blob_client( + container=container, blob=blob + ) as blob_client: + blob_download = await blob_client.download_blob() + blob_contents = await blob_download.readall() + + blob_properties = await blob_client.get_blob_properties() + + logging.info("Blob Downloaded") + # Get the temporary directory + temp_dir = tempfile.gettempdir() + + # Define the temporary file path + temp_file_path = os.path.join(temp_dir, target_file_name) + + # Write the blob contents to the temporary file + with open(temp_file_path, "wb") as temp_file: + temp_file.write(blob_contents) + + return temp_file_path, blob_properties.metadata + + async def upload_business_glossary_dataframe(self, df: str, sheet: str) -> str: + """Upload the business glossary dataframe to a JSONL file.""" + json_lines = df.to_json(orient="records", lines=True) + + container = os.environ["StorageAccount__BusinessGlossary__Container"] + blob = f"{sheet}.jsonl" + blob_service_client = await self.get_client() + async with blob_service_client: + async with blob_service_client.get_blob_client( + container=container, blob=blob + ) as blob_client: + await blob_client.upload_blob(json_lines, overwrite=True) diff --git a/ai_search_with_adi/function_apps/indexer/adi_2_aisearch.py b/ai_search_with_adi/function_apps/indexer/adi_2_ai_search.py similarity index 80% rename from ai_search_with_adi/function_apps/indexer/adi_2_aisearch.py rename to ai_search_with_adi/function_apps/indexer/adi_2_ai_search.py index e0542fb..a477a85 100644 --- a/ai_search_with_adi/function_apps/indexer/adi_2_aisearch.py +++ b/ai_search_with_adi/function_apps/indexer/adi_2_ai_search.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import base64 from azure.core.credentials import AzureKeyCredential from azure.ai.documentintelligence.aio import DocumentIntelligenceClient @@ -13,7 +16,7 @@ from common.storage_account import StorageAccountHelper import concurrent.futures import json - +from openai import AzureOpenAI def crop_image_from_pdf_page(pdf_path, page_number, bounding_box): """ @@ -131,36 +134,77 @@ def update_figure_description(md_content, img_description, idx): return new_md_content -async def understand_image_with_vlm(image_base64): +async def understand_image_with_gptv(image_base64, caption): """ - Sends a base64-encoded image to a VLM (Vision Language Model) endpoint for financial analysis. + Generates a description for an image using the GPT-4V model. - Args: - image_base64 (str): The base64-encoded string representation of the image. + Parameters: + - image_base64 (str): image file. + - caption (str): The caption for the image. Returns: - str: The response from the VLM, which is either a financial analysis or a statement indicating the image is not useful. + - img_description (str): The generated description for the image. """ - # prompt = "Describe the image ONLY IF it is useful for financial analysis. Otherwise, say 'NOT USEFUL IMAGE' and NOTHING ELSE. " - prompt = "Perform financial analysis of the image ONLY IF the image is of graph, chart, flowchart or table. Otherwise, say 'NOT USEFUL IMAGE' and NOTHING ELSE. " - headers = {"Content-Type": "application/json"} - data = {"prompt": prompt, "image": image_base64} - vlm_endpoint = os.environ["AIServices__VLM__Endpoint"] - async with aiohttp.ClientSession() as session: - async with session.post( - vlm_endpoint, headers=headers, json=data, timeout=30 - ) as response: - response_data = await response.json() - response_text = response_data["response"].split("")[0] - - if ( - "not useful for financial analysis" in response_text - or "NOT USEFUL IMAGE" in response_text - ): - return "Irrelevant Image" + + MAX_TOKENS = 2000 + api_key = os.environ["AzureAI_GPT4V_Key"] + api_version = os.environ["AzureAI__GPT4V_Version"] + deployment_name = os.environ["AzureAI__GPT4V_Deployment"] + api_base = os.environ["AzureAI__GPT4V_APIbase"] + + + client = AzureOpenAI( + api_key=api_key, + api_version=api_version, + base_url=f"{api_base}/openai/deployments/{deployment_name}" + ) + + # We send both image caption and the image body to GPTv for better understanding + if caption != "": + response = client.chat.completions.create( + model=deployment_name, + messages=[ + { "role": "system", "content": "You are a helpful assistant." }, + { "role": "user", "content": [ + { + "type": "text", + "text": f"Describe this image (note: it has image caption: {caption}):" + }, + { + "type": "image_base64", + "image_base64": { + "image": image_base64 + } + } + ] } + ], + max_tokens=MAX_TOKENS + ) + else: - return response_text + response = client.chat.completions.create( + model=deployment_name, + messages=[ + { "role": "system", "content": "You are a helpful assistant." }, + { "role": "user", "content": [ + { + "type": "text", + "text": "Describe this image:" + }, + { + "type": "image_base64", + "image_base64": { + "image": image_base64 + } + } + ] } + ], + max_tokens=MAX_TOKENS + ) + img_description = response.choices[0].message.content + + return img_description def pil_image_to_base64(image, image_format="JPEG"): """ @@ -219,7 +263,7 @@ async def process_figures_from_extracted_content( image_base64 = pil_image_to_base64(cropped_image) - img_description += await understand_image_with_vlm(image_base64) + img_description += await understand_image_with_gptv(image_base64) logging.info(f"\tDescription of figure {idx}: {img_description}") markdown_content = update_figure_description( @@ -385,46 +429,31 @@ async def process_adi_2_ai_search(record: dict, chunk_by_page: bool = False) -> } try: - if chunk_by_page: - markdown_content,page_no = create_page_wise_content(result) - else: - markdown_content = result.content - - # Remove this line when VLM is ready - content_with_figures = markdown_content - - # if chunk_by_page: - # tasks = [ - # process_figures_from_extracted_content( - # temp_file_path, page_content, result.figures, page_number=idx - # ) - # for idx, page_content in enumerate(markdown_content) - # ] - # content_with_figures = await asyncio.gather(*tasks) - # else: - # content_with_figures = await process_figures_from_extracted_content( - # temp_file_path, markdown_content, result.figures - # ) - - # Remove remove_irrelevant_figures=True when VLM is ready if chunk_by_page: cleaned_result = [] + markdown_content,page_no = create_page_wise_content(result) + tasks = [ + process_figures_from_extracted_content( + temp_file_path, page_content, result.figures, page_number=idx + ) + for idx, page_content in enumerate(markdown_content) + ] + content_with_figures = await asyncio.gather(*tasks) with concurrent.futures.ProcessPoolExecutor() as executor: - results = executor.map(clean_adi_markdown,content_with_figures, page_no,[False] * len(content_with_figures)) - - for cleaned_content in results: - cleaned_result.append(cleaned_content) - - # with concurrent.futures.ProcessPoolExecutor() as executor: - # futures = { - # executor.submit( - # clean_adi_markdown, page_content, False - # ): page_content - # for page_content in content_with_figures - # } - # for future in concurrent.futures.as_completed(futures): - # cleaned_result.append(future.result()) + futures = { + executor.submit( + clean_adi_markdown, page_content, False + ): page_content + for page_content in content_with_figures + } + for future in concurrent.futures.as_completed(futures): + cleaned_result.append(future.result()) + else: + markdown_content = result.content + content_with_figures = await process_figures_from_extracted_content( + temp_file_path, markdown_content, result.figures + ) cleaned_result = clean_adi_markdown( content_with_figures, page_no=-1,remove_irrelevant_figures=False ) diff --git a/ai_search_with_adi/function_apps/indexer/function_app.py b/ai_search_with_adi/function_apps/indexer/function_app.py index 12d5d5b..6057ec7 100644 --- a/ai_search_with_adi/function_apps/indexer/function_app.py +++ b/ai_search_with_adi/function_apps/indexer/function_app.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + from datetime import datetime, timedelta, timezone import azure.functions as func import logging diff --git a/ai_search_with_adi/function_apps/indexer/key_phrase_extraction.py b/ai_search_with_adi/function_apps/indexer/key_phrase_extraction.py index c6ab40e..d8c023b 100644 --- a/ai_search_with_adi/function_apps/indexer/key_phrase_extraction.py +++ b/ai_search_with_adi/function_apps/indexer/key_phrase_extraction.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import logging import json import os diff --git a/ai_search_with_adi/function_apps/indexer/ocr.py b/ai_search_with_adi/function_apps/indexer/ocr.py new file mode 100644 index 0000000..e179eb1 --- /dev/null +++ b/ai_search_with_adi/function_apps/indexer/ocr.py @@ -0,0 +1,86 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import logging +import os +from azure.ai.vision.imageanalysis.aio import ImageAnalysisClient +from azure.ai.vision.imageanalysis.models import VisualFeatures +from azure.core.credentials import AzureKeyCredential + + +async def process_ocr(record: dict) -> dict: + logging.info("Python HTTP trigger function processed a request.") + + try: + url = record["data"]["image"]["url"] + logging.info(f"Request Body: {record}") + except KeyError: + return { + "recordId": record["recordId"], + "data": {}, + "errors": [ + { + "message": "Failed to extract data with ocr. Pass a valid source in the request body.", + } + ], + "warnings": None, + } + else: + logging.info(f"image url: {url}") + + if url is not None: + try: + client = ImageAnalysisClient( + endpoint=os.environ["AIService__Services__Endpoint"], + credential=AzureKeyCredential( + os.environ["AIService__Services__Key"] + ), + ) + result = await client.analyze_from_url( + image_url=url, visual_features=[VisualFeatures.READ] + ) + logging.info("logging output") + + # Extract text from OCR results + text = " ".join([line.text for line in result.read.blocks[0].lines]) + logging.info(text) + + except KeyError as e: + logging.error(e) + logging.error(f"Failed to authenticate with ocr: {e}") + return { + "recordId": record["recordId"], + "data": {}, + "errors": [ + { + "message": f"Failed to authenticate with Ocr. Check the service credentials exist. {e}", + } + ], + "warnings": None, + } + except Exception as e: + logging.error(e) + logging.error( + f"Failed to analyze the document with Azure Document Intelligence: {e}" + ) + logging.error(e.InnerError) + return { + "recordId": record["recordId"], + "data": {}, + "errors": [ + { + "message": f"Failed to analyze the document with ocr. Check the source and try again. {e}", + } + ], + "warnings": None, + } + else: + return { + "recordId": record["recordId"], + "data": {"text": ""}, + } + + return { + "recordId": record["recordId"], + "data": {"text": text}, + } diff --git a/ai_search_with_adi/function_apps/indexer/pending_index_completion.py b/ai_search_with_adi/function_apps/indexer/pending_index_completion.py index e69de29..3488488 100644 --- a/ai_search_with_adi/function_apps/indexer/pending_index_completion.py +++ b/ai_search_with_adi/function_apps/indexer/pending_index_completion.py @@ -0,0 +1,107 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from common.ai_search import AISearchHelper, IndexerStatusEnum +from common.service_bus import ServiceBusHelper +from common.payloads.pending_index_completion import PendingIndexCompletionPayload +from common.payloads.pending_index_trigger import PendingIndexTriggerPayload +from common.payloads.header import TaskEnum, DataTypeEnum +from common.payloads.error import Error +from datetime import datetime, timedelta, timezone +from common.delay_processing_exception import DelayProcessingException +import asyncio + + +async def process_pending_index_completion(payload: PendingIndexCompletionPayload): + """Process the pending index completion.""" + ai_search_helper = AISearchHelper() + service_bus_helper = ServiceBusHelper() + + status, indexer_start_time = await ai_search_helper.get_indexer_status( + payload.body.indexer + ) + request_time = payload.header.creation_timestamp + enqueue_time = None + queue = None + messages = [] + retry = False + + if status == IndexerStatusEnum.RETRIGGER and payload.header.retries_remaining > 0: + # Trigger the indexer + await ai_search_helper.trigger_indexer(payload.body.indexer) + + errors = [error_item.model_dump() for error_item in payload.errors] + errors.append( + Error( + code="IndexerNotCompleted", + message="Indexer was was in failed state and required retriggering.", + ) + ) + messages.append( + PendingIndexCompletionPayload( + header=payload.header.model_dump(), + body=payload.body.model_dump(), + errors=errors, + ) + ) + queue = TaskEnum.PENDING_INDEX_COMPLETION + minutes = 2 ** (11 - payload.header.retries_remaining) + enqueue_time = datetime.now(timezone.utc) + timedelta(minutes=minutes) + retry = True + elif status == IndexerStatusEnum.RUNNING and payload.header.retries_remaining > 0: + errors = [error_item.model_dump() for error_item in payload.errors] + errors.append( + Error( + code="IndexerNotCompleted", + message="Indexer was completed not at the time of running.", + ) + ) + messages.append( + PendingIndexCompletionPayload( + header=payload.header.model_dump(), + body=payload.body.model_dump(), + errors=errors, + ) + ) + queue = TaskEnum.PENDING_INDEX_COMPLETION + minutes = 2 ** (11 - payload.header.retries_remaining) + enqueue_time = datetime.now(timezone.utc) + timedelta(minutes=minutes) + retry = True + elif ( + status == IndexerStatusEnum.SUCCESS + and indexer_start_time <= request_time + and payload.header.retries_remaining > 0 + ): + errors = [error_item.model_dump() for error_item in payload.errors] + errors.append( + Error( + code="IndexerNotTriggered", + message="Indexer was not triggered.", + ) + ) + messages.append( + PendingIndexTriggerPayload( + header=payload.header.model_dump(), + body=payload.body.model_dump(), + errors=errors, + ) + ) + queue = TaskEnum.PENDING_INDEX_TRIGGER + minutes = 2 ** (11 - payload.header.retries_remaining) + enqueue_time = datetime.now(timezone.utc) + timedelta(minutes=minutes) + retry = True + else: + raise DelayProcessingException( + "Failed to run trigger due to maximum retries exceeded." + ) + + if queue is not None and len(messages) > 0: + message_tasks = [] + for message in messages: + message_tasks.append( + service_bus_helper.send_message_to_service_bus_queue( + queue, message, enqueue_time=enqueue_time, retry=retry + ) + ) + + await asyncio.gather(*message_tasks) diff --git a/ai_search_with_adi/function_apps/indexer/pending_index_trigger.py b/ai_search_with_adi/function_apps/indexer/pending_index_trigger.py new file mode 100644 index 0000000..f803623 --- /dev/null +++ b/ai_search_with_adi/function_apps/indexer/pending_index_trigger.py @@ -0,0 +1,94 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from common.ai_search import AISearchHelper, IndexerStatusEnum +from common.service_bus import ServiceBusHelper +from common.payloads.pending_index_trigger import PendingIndexTriggerPayload +from common.payloads.pending_index_completion import PendingIndexCompletionPayload +from common.payloads.header import TaskEnum +from datetime import datetime, timedelta, timezone +from common.delay_processing_exception import DelayProcessingException +from common.payloads.error import Error + + +async def process_pending_index_trigger(payload: PendingIndexTriggerPayload): + """Process the pending index trigger.""" + + ai_search_helper = AISearchHelper() + service_bus_helper = ServiceBusHelper() + + status, indexer_start_time = await ai_search_helper.get_indexer_status( + payload.body.indexer + ) + request_time = payload.header.last_processed_timestamp + enqueue_time = None + queue = None + message = None + retry = False + + if status == IndexerStatusEnum.SUCCESS and indexer_start_time > request_time: + errors = [error_item.model_dump() for error_item in payload.errors] + message = PendingIndexCompletionPayload( + header=payload.header.model_dump(), + body=payload.body.model_dump(), + errors=errors, + ) + queue = TaskEnum.PENDING_INDEX_COMPLETION + elif status == IndexerStatusEnum.RETRIGGER or status == IndexerStatusEnum.SUCCESS: + # Trigger the indexer + await ai_search_helper.trigger_indexer(payload.body.indexer) + + errors = [error_item.model_dump() for error_item in payload.errors] + + if status == IndexerStatusEnum.RETRIGGER: + errors.append( + Error( + code="IndexerNotCompleted", + message="Indexer was was in failed state and required retriggering.", + ) + ) + + message = PendingIndexCompletionPayload( + header=payload.header.model_dump(), + body=payload.body.model_dump(), + errors=errors, + ) + queue = TaskEnum.PENDING_INDEX_COMPLETION + elif status == IndexerStatusEnum.RUNNING and indexer_start_time > request_time: + errors = [error_item.model_dump() for error_item in payload.errors] + message = PendingIndexCompletionPayload( + header=payload.header.model_dump(), + body=payload.body.model_dump(), + errors=errors, + ) + queue = TaskEnum.PENDING_INDEX_COMPLETION + elif ( + status == IndexerStatusEnum.RUNNING + and indexer_start_time <= request_time + and payload.header.retries_remaining > 0 + ): + errors = [error_item.model_dump() for error_item in payload.errors] + errors.append( + Error( + code="IndexerAlreadyRunning", + message="Indexer is already running for an outstanding request.", + ) + ) + message = PendingIndexTriggerPayload( + header=payload.header.model_dump(), + body=payload.body.model_dump(), + errors=errors, + ) + queue = TaskEnum.PENDING_INDEX_TRIGGER + minutes = 2 ** (11 - payload.header.retries_remaining) + enqueue_time = datetime.now(timezone.utc) + timedelta(minutes=minutes) + retry = True + else: + raise DelayProcessingException( + "Failed to run trigger due to maximum retries exceeded." + ) + + if queue is not None: + await service_bus_helper.send_message_to_service_bus_queue( + queue, message, enqueue_time=enqueue_time, retry=retry + ) diff --git a/ai_search_with_adi/function_apps/indexer/pre_embedding_cleaner.py b/ai_search_with_adi/function_apps/indexer/pre_embedding_cleaner.py index 2fdf87a..79cbaae 100644 --- a/ai_search_with_adi/function_apps/indexer/pre_embedding_cleaner.py +++ b/ai_search_with_adi/function_apps/indexer/pre_embedding_cleaner.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + import logging import json import string diff --git a/ai_search_with_adi/function_apps/indexer/text_split.py b/ai_search_with_adi/function_apps/indexer/text_split.py new file mode 100644 index 0000000..8121c70 --- /dev/null +++ b/ai_search_with_adi/function_apps/indexer/text_split.py @@ -0,0 +1,355 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import spacy +import logging +from transformers import AutoTokenizer, AutoModelForSeq2SeqLM +import json +from sklearn.metrics.pairwise import cosine_similarity + +nlp = spacy.load("en_core_web_md") + + +class RecursiveCharacterTextSplitter: + def __init__(self, fragment_size=100, division_chars=["\n\n", "\n", " ", ""]): + self.fragment_size = fragment_size + self.division_chars = division_chars + + def split_text(self, text): + return self._recursive_split(text, 0) + + def _recursive_split(self, text, char_idx): + if len(text) <= self.fragment_size or char_idx >= len(self.division_chars): + return [text] + + char = self.division_chars[char_idx] + fragments = text.split(char) + result = [] + current_fragment = "" + + for fragment in fragments: + if len(current_fragment) + len(fragment) + len(char) <= self.fragment_size: + current_fragment += char + fragment + else: + if current_fragment: + result.append(current_fragment) + current_fragment = fragment + + if current_fragment: + result.append(current_fragment) + + if any(len(frag) > self.fragment_size for frag in result): + return self._recursive_split(text, char_idx + 1) + + return result + + +class CharacterTextSplitter: + def __init__(self, fragment_size=100, separator=" "): + self.fragment_size = fragment_size + self.separator = separator + + def split_text(self, text): + fragments = text.split(self.separator) + result = [] + current_fragment = "" + + for fragment in fragments: + if ( + len(current_fragment) + len(fragment) + len(self.separator) + <= self.fragment_size + ): + current_fragment += self.separator + fragment + else: + if current_fragment: + result.append(current_fragment) + current_fragment = fragment + + if current_fragment: + result.append(current_fragment) + + return result + + +class RecursiveTextSplitter: + def __init__(self, fragment_size=100, division_tokens=["\n\n", "\n", " ", ""]): + self.fragment_size = fragment_size + self.division_tokens = division_tokens + + def split_text(self, text): + return self._recursive_split(text, 0) + + def _recursive_split(self, text, token_idx): + if len(text) <= self.fragment_size or token_idx >= len(self.division_tokens): + return [text] + + token = self.division_tokens[token_idx] + fragments = text.split(token) + result = [] + current_fragment = "" + + for fragment in fragments: + if len(current_fragment) + len(fragment) + len(token) <= self.fragment_size: + current_fragment += token + fragment + else: + if current_fragment: + result.append(current_fragment) + current_fragment = fragment + + if current_fragment: + result.append(current_fragment) + + if any(len(frag) > self.fragment_size for frag in result): + return self._recursive_split(text, token_idx + 1) + + return result + + +class SemanticDoubleMergingSplitterNodeParser: + def __init__( + self, + initial_threshold=0.8, + appending_threshold=0.7, + merging_threshold=0.75, + fragment_size=100, + spacy_model="en_core_web_md", + ): + self.initial_threshold = initial_threshold + self.appending_threshold = appending_threshold + self.merging_threshold = merging_threshold + self.fragment_size = fragment_size + try: + self.nlp = spacy.load(spacy_model) + except IOError: + raise ValueError( + f"Spacy model '{spacy_model}' not found. Please download it using 'python -m spacy download {spacy_model}'" + ) + + def split_text(self, text): + sentences = self._split_into_sentences(text) + initial_chunks = self._initial_pass(sentences) + final_chunks = self._second_pass(initial_chunks) + return final_chunks + + def _split_into_sentences(self, text): + doc = self.nlp(text) + sentences = [sent.text for sent in doc.sents] + return sentences + + def _initial_pass(self, sentences): + chunks = [] + current_chunk = [] + + i = 0 + while i < len(sentences): + current_chunk.append(sentences[i]) + if len(current_chunk) >= 2: + cosine_sim = self._cosine_similarity( + " ".join(current_chunk[-2:]), sentences[i] + ) + if ( + cosine_sim < self.initial_threshold + or len(" ".join(current_chunk)) > self.fragment_size + ): + if len(current_chunk) > 2: + chunks.append(" ".join(current_chunk[:-1])) + current_chunk = [current_chunk[-1]] + else: + chunks.append(current_chunk[0]) + current_chunk = [current_chunk[1]] + i += 1 + + if current_chunk: + chunks.append(" ".join(current_chunk)) + + return chunks + + def _second_pass(self, chunks): + merged_chunks = [] + current_chunk = chunks[0] + + i = 1 + while i < len(chunks): + cosine_sim = self._cosine_similarity(current_chunk, chunks[i]) + if ( + cosine_sim >= self.merging_threshold + and len(current_chunk + " " + chunks[i]) <= self.fragment_size + ): + current_chunk += " " + chunks[i] + else: + merged_chunks.append(current_chunk) + current_chunk = chunks[i] + i += 1 + + merged_chunks.append(current_chunk) + return merged_chunks + + def _cosine_similarity(self, text1, text2): + vec1 = self.nlp(text1).vector + vec2 = self.nlp(text2).vector + return cosine_similarity([vec1], [vec2])[0, 0] + + +class FlanT5Chunker: + def __init__( + self, model_name="chentong00/propositionizer-wiki-flan-t5-large", device="cpu" + ): + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device) + self.device = device + self.max_length = 512 # Model's maximum token length + + def flan_t5_chunking(self, text, chunk_size=500, stride=20): + input_text = f"Title: . Section: . Content: {text}" + input_ids = self.tokenizer(input_text, return_tensors="pt").input_ids.to( + self.device + ) + total_length = input_ids.shape[1] + + chunks = [] + for i in range(0, total_length, chunk_size - stride): + end = min(i + chunk_size, total_length) + chunk_input_ids = input_ids[:, i:end] + outputs = self.model.generate( + chunk_input_ids, max_new_tokens=self.max_length + ).cpu() + output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True) + try: + prop_list = json.loads(output_text) + except json.JSONDecodeError: + prop_list = [] + print("[ERROR] Failed to parse output text as JSON.") + chunks.append(prop_list) + + # Flatten the list of lists + return [item for sublist in chunks for item in sublist] + + +def clean_input(value): + """Clean the input value. + + Args: + value: The input value. + + Returns: + The cleaned value.""" + if isinstance(value, str): + return value.strip('"') + return value + + +async def process_text_split(record: dict, text_split_config: dict) -> dict: + """Process the text split request. + + Args: + record (dict): The request record. + text_split_config (dict): The headers for config. + + Returns: + dict: The response record. + """ + try: + data = record["data"] + text = clean_input(data.get("text")) + logging.info(f"Request Body: {record}") + except KeyError: + return { + "recordId": record["recordId"], + "data": {}, + "errors": [ + { + "message": "Failed to split text. Pass valid parameters.", + } + ], + "warnings": None, + } + else: + if text is None: + logging.error("Failed to split text. Pass valid text.") + return { + "recordId": record["recordId"], + "data": {}, + "errors": [ + { + "message": "Failed to split text. Pass valid text.", + } + ], + "warnings": None, + } + + splitter_type = clean_input( + text_split_config.get("text_split_mode", "recursive_character") + ) + fragment_size = float( + clean_input(text_split_config.get("maximum_page_length", 100)) + ) + separator = clean_input(text_split_config.get("separator", " ")) + initial_threshold = float( + clean_input(text_split_config.get("initial_threshold", 0.8)) + ) + appending_threshold = float( + clean_input(text_split_config.get("appending_threshold", 0.7)) + ) + merging_threshold = float( + clean_input(text_split_config.get("merging_threshold", 0.75)) + ) + + try: + if splitter_type == "recursive_character": + splitter = RecursiveCharacterTextSplitter(fragment_size=fragment_size) + elif splitter_type == "character": + splitter = CharacterTextSplitter( + fragment_size=fragment_size, separator=separator + ) + elif splitter_type == "recursive": + splitter = RecursiveTextSplitter(fragment_size=fragment_size) + elif splitter_type == "semantic": + splitter = SemanticDoubleMergingSplitterNodeParser( + initial_threshold=initial_threshold, + appending_threshold=appending_threshold, + merging_threshold=merging_threshold, + fragment_size=fragment_size, + ) + elif splitter_type == "flan_t5": + splitter = FlanT5Chunker() + else: + logging.error("Failed to split text. Pass valid splitter type.") + logging.error(f"Splitter Type: {splitter_type}") + return { + "recordId": record["recordId"], + "data": {}, + "errors": [ + { + "message": "Failed to split text. Pass valid splitter type.", + } + ], + "warnings": None, + } + + if splitter_type == "flan_t5": + chunks = splitter.flan_t5_chunking(text) + else: + chunks = splitter.split_text(text) + except Exception as e: + logging.error(f"Error during splitting: {e}") + + return { + "recordId": record["recordId"], + "data": {}, + "errors": [ + { + "message": f"Failed to split text. Check function app logs for more details of exact failure. {str(e)}", + } + ], + "warnings": None, + } + + else: + return { + "recordId": record["recordId"], + "data": { + "chunks": chunks, + }, + "errors": None, + "warnings": None, + }