Commit: Update the deployment script

BenConstable9 committed Sep 10, 2024
1 parent ad8684f commit 42adc2a
Showing 3 changed files with 78 additions and 81 deletions.
88 changes: 50 additions & 38 deletions ai_search_with_adi/ai_search/ai_search.py
@@ -20,7 +20,6 @@
     SearchIndexerDataContainer,
     SearchIndexerDataSourceConnection,
     SearchIndexerDataSourceType,
-    SearchIndexerDataUserAssignedIdentity,
     OutputFieldMappingEntry,
     InputFieldMappingEntry,
     SynonymMap,
@@ -29,29 +28,21 @@
 )
 from azure.core.exceptions import HttpResponseError
 from azure.search.documents.indexes import SearchIndexerClient, SearchIndexClient
-from ai_search_with_adi.ai_search.environment import (
-    get_fq_blob_connection_string,
-    get_blob_container_name,
-    get_custom_skill_function_url,
-    get_managed_identity_fqname,
-    get_function_app_authresourceid,
-)
+from ai_search_with_adi.ai_search.environment import AISearchEnvironment, IdentityType


 class AISearch(ABC):
     """Handles the deployment of the AI search pipeline."""

     def __init__(
         self,
-        endpoint: str,
-        credential,
         suffix: str | None = None,
         rebuild: bool | None = False,
     ):
         """Initialize the AI search class
         Args:
             endpoint (str): The search endpoint
             credential (AzureKeyCredential): The search credential
-            suffix (str, optional): The suffix for the indexer. Defaults to None.
+            suffix (str, optional): The suffix for the indexer. Defaults to None. If a suffix is provided, it is assumed to be a test indexer.
             rebuild (bool, optional): Whether to rebuild the index. Defaults to False.
         """
         self.indexer_type = None
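
Note: the scattered get_* environment helpers are replaced by a single AISearchEnvironment object plus an IdentityType enum. The class itself is not part of this diff; the sketch below reconstructs only the interface the new code calls, with every member name inferred from call sites in this commit, so treat it as an assumption rather than the real environment.py:

    # Hypothetical sketch of the AISearchEnvironment interface this commit
    # relies on. Every member below is inferred from a call site in the diff;
    # the real environment.py is not shown here, so treat all names as assumptions.
    from enum import Enum


    class IdentityType(Enum):
        """How the deployment authenticates to Azure resources."""

        KEY = "key"
        SYSTEM_ASSIGNED = "system_assigned"  # assumed member name
        USER_ASSIGNED = "user_assigned"


    class AISearchEnvironment:
        """Single source of configuration, replacing the old get_* helpers."""

        def __init__(self, indexer_type):
            self.indexer_type = indexer_type

        # Attributes the diff reads (types inferred from usage):
        ai_search_endpoint: str        # search service endpoint URL
        ai_search_credential: object   # AzureKeyCredential or token credential
        ai_search_identity_id: object  # identity attached to data sources and skills
        ai_search_user_assigned_identity: str  # resource ID of the user-assigned identity
        blob_connection_string: str    # fully qualified storage connection string
        identity_type: IdentityType
        embedding_model_dimensions: int  # e.g. 1536 for text-embedding-ada-002
        use_private_endpoint: bool

        def get_blob_container_name(self) -> str: ...

        def get_custom_skill_function_url(self, skill_type: str) -> str: ...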
@@ -61,15 +52,22 @@ def __init__(
         else:
             self.rebuild = False

+        # If suffix is None, then it is not a test indexer. A test indexer limits the rate of indexing and turns off the schedule, which is useful for testing index changes.
         if suffix is None:
             self.suffix = ""
             self.test = False
         else:
             self.suffix = f"-{suffix}-test"
             self.test = True

-        self._search_indexer_client = SearchIndexerClient(endpoint, credential)
-        self._search_index_client = SearchIndexClient(endpoint, credential)
+        self.environment = AISearchEnvironment(indexer_type=self.indexer_type)
+
+        self._search_indexer_client = SearchIndexerClient(
+            self.environment.ai_search_endpoint, self.environment.ai_search_credential
+        )
+        self._search_index_client = SearchIndexClient(
+            self.environment.ai_search_endpoint, self.environment.ai_search_credential
+        )

     @property
     def indexer_name(self):
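
With the suffix logic above, a test deployment can sit alongside the production one. A quick illustration, using the RagDocumentsAISearch subclass defined in rag_documents.py further down (the container name is a placeholder; data_source_name follows the pattern shown in the next hunk):

    # Illustration only, under the assumptions above.
    search = RagDocumentsAISearch(suffix="pr42")
    search.suffix             # "-pr42-test"
    search.test               # True: smaller batches, no schedule (see get_indexer)
    search.data_source_name   # "<container>-data-source-pr42-test"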
@@ -94,7 +92,7 @@ def index_name(self):
     @property
     def data_source_name(self):
         """Get the data source name for the indexer."""
-        blob_container_name = get_blob_container_name(self.indexer_type)
+        blob_container_name = self.environment.get_blob_container_name()
         return f"{blob_container_name}-data-source{self.suffix}"

     @property
@@ -146,16 +144,6 @@ def get_synonym_map_names(self) -> list[str]:
         """Get the synonym map names for the indexer."""
         return []

-    def get_user_assigned_managed_identity(
-        self,
-    ) -> SearchIndexerDataUserAssignedIdentity:
-        """Get user assigned managed identity details"""
-
-        user_assigned_identity = SearchIndexerDataUserAssignedIdentity(
-            user_assigned_identity=get_managed_identity_fqname()
-        )
-        return user_assigned_identity
-
     def get_data_source(self) -> SearchIndexerDataSourceConnection:
         """Get the data source for the indexer."""

@@ -166,19 +154,21 @@ def get_data_source(self) -> SearchIndexerDataSourceConnection:
         )

         container = SearchIndexerDataContainer(
-            name=get_blob_container_name(self.indexer_type)
+            name=self.environment.get_blob_container_name()
         )

         data_source_connection = SearchIndexerDataSourceConnection(
             name=self.data_source_name,
             type=SearchIndexerDataSourceType.AZURE_BLOB,
-            connection_string=get_fq_blob_connection_string(),
+            connection_string=self.environment.blob_connection_string,
             container=container,
             data_change_detection_policy=data_change_detection_policy,
             data_deletion_detection_policy=data_deletion_detection_policy,
-            identity=self.get_user_assigned_managed_identity(),
         )

+        if self.environment.identity_type != IdentityType.KEY:
+            data_source_connection.identity = self.environment.ai_search_identity_id
+
         return data_source_connection

     def get_pre_embedding_cleaner_skill(
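
Under key-based auth the data source's identity field is now simply left unset; otherwise the managed identity is attached after construction. Deploying the returned connection would typically go through the standard SearchIndexerClient call, sketched here (the surrounding driver code is assumed, not shown in this diff):

    # Sketch: pushing the data source to the search service.
    data_source_connection = self.get_data_source()
    self._search_indexer_client.create_or_update_data_source_connection(
        data_source_connection
    )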
@@ -226,17 +216,25 @@
             name="Pre Embedding Cleaner Skill",
             description="Skill to clean the data before sending to embedding",
             context=context,
-            uri=get_custom_skill_function_url("pre_embedding_cleaner"),
+            uri=self.environment.get_custom_skill_function_url("pre_embedding_cleaner"),
             timeout="PT230S",
             batch_size=batch_size,
             degree_of_parallelism=degree_of_parallelism,
             http_method="POST",
             inputs=pre_embedding_cleaner_skill_inputs,
             outputs=pre_embedding_cleaner_skill_outputs,
-            auth_resource_id=get_function_app_authresourceid(),
-            auth_identity=self.get_user_assigned_managed_identity(),
         )

+        if self.environment.identity_type != IdentityType.KEY:
+            pre_embedding_cleaner_skill.auth_identity = (
+                self.environment.ai_search_identity_id
+            )
+
+        if self.environment.identity_type == IdentityType.USER_ASSIGNED:
+            pre_embedding_cleaner_skill.auth_resource_id = (
+                self.environment.ai_search_user_assigned_identity
+            )
+
         return pre_embedding_cleaner_skill

     def get_text_split_skill(self, context, source) -> SplitSkill:
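
The same two-branch pattern is applied to all three custom WebApiSkill builders in this commit: auth_identity is set for any managed identity, and auth_resource_id is additionally set when the identity is user-assigned. A helper like the sketch below could centralize it; this is a possible refactor, not part of the commit:

    # Sketch of a shared helper (not in the commit).
    def _apply_function_app_auth(self, skill: WebApiSkill) -> WebApiSkill:
        """Attach managed identity auth settings to a custom skill."""
        if self.environment.identity_type != IdentityType.KEY:
            skill.auth_identity = self.environment.ai_search_identity_id
        if self.environment.identity_type == IdentityType.USER_ASSIGNED:
            skill.auth_resource_id = self.environment.ai_search_user_assigned_identity
        return skill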
@@ -294,7 +292,7 @@ def get_adi_skill(self, chunk_by_page=False) -> WebApiSkill:
             name="ADI Skill",
             description="Skill to generate ADI",
             context="/document",
-            uri=get_custom_skill_function_url("adi"),
+            uri=self.environment.get_custom_skill_function_url("adi"),
             timeout="PT230S",
             batch_size=batch_size,
             degree_of_parallelism=degree_of_parallelism,
@@ -306,10 +304,16 @@
                 )
             ],
             outputs=output,
-            auth_resource_id=get_function_app_authresourceid(),
-            auth_identity=self.get_user_assigned_managed_identity(),
         )

+        if self.environment.identity_type != IdentityType.KEY:
+            adi_skill.auth_identity = self.environment.ai_search_identity_id
+
+        if self.environment.identity_type == IdentityType.USER_ASSIGNED:
+            adi_skill.auth_resource_id = (
+                self.environment.ai_search_user_assigned_identity
+            )
+
         return adi_skill

     def get_vector_skill(
@@ -368,17 +372,25 @@ def get_key_phrase_extraction_skill(self, context, source) -> WebApiSkill:
             name="Key phrase extraction API",
             description="Skill to extract keyphrases",
             context=context,
-            uri=get_custom_skill_function_url("keyphraseextraction"),
+            uri=self.environment.get_custom_skill_function_url("key_phrase_extraction"),
             timeout="PT230S",
             batch_size=batch_size,
             degree_of_parallelism=degree_of_parallelism,
             http_method="POST",
             inputs=keyphrase_extraction_skill_inputs,
             outputs=keyphrase_extraction__skill_outputs,
-            auth_resource_id=get_function_app_authresourceid(),
-            auth_identity=self.get_user_assigned_managed_identity(),
         )

+        if self.environment.identity_type != IdentityType.KEY:
+            key_phrase_extraction_skill.auth_identity = (
+                self.environment.ai_search_identity_id
+            )
+
+        if self.environment.identity_type == IdentityType.USER_ASSIGNED:
+            key_phrase_extraction_skill.auth_resource_id = (
+                self.environment.ai_search_user_assigned_identity
+            )
+
         return key_phrase_extraction_skill

     def get_vector_search(self) -> VectorSearch:
43 changes: 10 additions & 33 deletions ai_search_with_adi/ai_search/deploy.py
@@ -1,49 +1,26 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.

 import argparse
-from ai_search_with_adi.ai_search.environment import (
-    get_search_endpoint,
-    get_managed_identity_id,
-    get_search_key,
-    get_key_vault_url,
-)
-from azure.core.credentials import AzureKeyCredential
-from azure.identity import DefaultAzureCredential
-from azure.keyvault.secrets import SecretClient
 from ai_search_with_adi.ai_search.rag_documents import RagDocumentsAISearch


-def main(args):
-    endpoint = get_search_endpoint()
-
-    try:
-        credential = DefaultAzureCredential(
-            managed_identity_client_id=get_managed_identity_id()
-        )
-        # initializing key vault client
-        client = SecretClient(vault_url=get_key_vault_url(), credential=credential)
-        print("Using managed identity credential")
-    except Exception as e:
-        print(e)
-        credential = AzureKeyCredential(get_search_key(client=client))
-        print("Using Azure Key credential")
-
-    if args.indexer_type == "rag":
-        # Deploy the inquiry index
+def deploy_config(arguments: argparse.Namespace):
+    """Deploy the indexer configuration based on the arguments passed.
+    Args:
+        arguments (argparse.Namespace): The arguments passed to the script"""
+    if arguments.indexer_type == "rag":
         index_config = RagDocumentsAISearch(
-            endpoint=endpoint,
-            credential=credential,
-            suffix=args.suffix,
-            rebuild=args.rebuild,
-            enable_page_by_chunking=args.enable_page_chunking,
+            suffix=arguments.suffix,
+            rebuild=arguments.rebuild,
+            enable_page_by_chunking=arguments.enable_page_chunking,
         )
     else:
         raise ValueError("Invalid Indexer Type")

     index_config.deploy()

-    if args.rebuild:
+    if arguments.rebuild:
         index_config.reset_indexer()

@@ -75,4 +52,4 @@ def main(args):
     )

     args = parser.parse_args()
-    main(args)
+    deploy_config(args)
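
The argparse setup itself is collapsed in this view, so the flag names below are assumptions inferred from the attributes deploy_config reads (indexer_type, suffix, rebuild, enable_page_chunking); a plausible reconstruction:

    # Hypothetical reconstruction of the collapsed parser block.
    parser = argparse.ArgumentParser(description="Deploy the AI Search configuration.")
    parser.add_argument("--indexer_type", required=True, help="Index configuration to deploy, e.g. 'rag'.")
    parser.add_argument("--suffix", default=None, help="Optional suffix marking a test indexer.")
    parser.add_argument("--rebuild", action="store_true", help="Reset and rebuild the index.")
    parser.add_argument("--enable_page_chunking", action="store_true", help="Chunk documents page by page.")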
28 changes: 18 additions & 10 deletions ai_search_with_adi/ai_search/rag_documents.py
@@ -24,7 +24,6 @@
 )
 from ai_search import AISearch
 from ai_search_with_adi.ai_search.environment import (
-    get_search_embedding_model_dimensions,
     IndexerType,
 )

@@ -34,13 +33,17 @@ class RagDocumentsAISearch(AISearch):

     def __init__(
         self,
-        endpoint,
-        credential,
-        suffix=None,
-        rebuild=False,
+        suffix: str | None = None,
+        rebuild: bool | None = False,
         enable_page_by_chunking=False,
     ):
-        super().__init__(endpoint, credential, suffix, rebuild)
+        """Initialize the RagDocumentsAISearch class. This class implements the deployment of the rag document index.
+        Args:
+            suffix (str, optional): The suffix for the indexer. Defaults to None. If a suffix is provided, it is assumed to be a test indexer.
+            rebuild (bool, optional): Whether to rebuild the index. Defaults to False.
+        """
+        super().__init__(suffix, rebuild)

         self.indexer_type = IndexerType.RAG_DOCUMENTS
         if enable_page_by_chunking is not None:
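
Because endpoint and credential are now resolved internally via AISearchEnvironment, callers pass only behavioral flags. Mirroring deploy.py above:

    # Example instantiation under the new signature.
    index_config = RagDocumentsAISearch(
        suffix=None,                   # production deployment, schedule enabled
        rebuild=True,                  # reset the indexer after deployment
        enable_page_by_chunking=False,
    )
    index_config.deploy()  # deploy() is what deploy_config() calls in deploy.py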
@@ -80,9 +83,7 @@ def get_index_fields(self) -> list[SearchableField]:
             SearchField(
                 name="ChunkEmbedding",
                 type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
-                vector_search_dimensions=get_search_embedding_model_dimensions(
-                    self.indexer_type
-                ),
+                vector_search_dimensions=self.environment.embedding_model_dimensions,
                 vector_search_profile_name=self.vector_search_profile_name,
             ),
             SearchableField(
@@ -224,19 +225,26 @@ def get_indexer(self) -> SearchIndexer:
         Returns:
             SearchIndexer: The indexer for inquiry document"""

+        # Only place on a schedule if it is not a test deployment
         if self.test:
             schedule = None
             batch_size = 4
         else:
             schedule = {"interval": "PT15M"}
             batch_size = 16

+        if self.environment.use_private_endpoint:
+            execution_environment = IndexerExecutionEnvironment.PRIVATE
+        else:
+            execution_environment = IndexerExecutionEnvironment.STANDARD
+
         indexer_parameters = IndexingParameters(
             batch_size=batch_size,
             configuration=IndexingParametersConfiguration(
                 data_to_extract=BlobIndexerDataToExtract.ALL_METADATA,
                 query_timeout=None,
-                execution_environment=IndexerExecutionEnvironment.PRIVATE,
+                execution_environment=execution_environment,
                 fail_on_unprocessable_document=False,
                 fail_on_unsupported_content_type=False,
                 index_storage_metadata_only_for_oversized_documents=True,
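
The execution environment is no longer hard-coded to PRIVATE, so deployments without a private endpoint fall back to the standard environment. The tail of get_indexer is collapsed in this view; a sketch of how the computed schedule and parameters would typically assemble into the indexer (field values are assumptions):

    # Sketch of the collapsed tail of get_indexer (assumed, based on the
    # schedule, batch_size, and indexer_parameters values computed above).
    indexer = SearchIndexer(
        name=self.indexer_name,
        skillset_name=self.skillset_name,
        target_index_name=self.index_name,
        data_source_name=self.data_source_name,
        schedule=schedule,
        parameters=indexer_parameters,
    )
    return indexer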
