Skip to content

Commit

Permalink
adi and indexer changes
Browse files Browse the repository at this point in the history
  • Loading branch information
priyal1508 committed Sep 9, 2024
1 parent b4b1409 commit e9a0b8e
Show file tree
Hide file tree
Showing 22 changed files with 1,087 additions and 231 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from abc import ABC, abstractmethod
from azure.search.documents.indexes.models import (
SearchIndex,
Expand Down Expand Up @@ -28,7 +31,7 @@
)
from azure.core.exceptions import HttpResponseError
from azure.search.documents.indexes import SearchIndexerClient, SearchIndexClient
from environment import (
from ai_search_with_adi.ai_search.environment import (
get_fq_blob_connection_string,
get_blob_container_name,
get_custom_skill_function_url,
Expand Down Expand Up @@ -70,31 +73,48 @@ def __init__(

@property
def indexer_name(self):
"""Get the indexer name for the indexer."""
return f"{str(self.indexer_type.value)}-indexer{self.suffix}"

@property
def skillset_name(self):
"""Get the skillset name for the indexer."""
return f"{str(self.indexer_type.value)}-skillset{self.suffix}"

@property
def semantic_config_name(self):
"""Get the semantic config name for the indexer."""
return f"{str(self.indexer_type.value)}-semantic-config{self.suffix}"

@property
def index_name(self):
"""Get the index name for the indexer."""
return f"{str(self.indexer_type.value)}-index{self.suffix}"

@property
def data_source_name(self):
"""Get the data source name for the indexer."""
blob_container_name = get_blob_container_name(self.indexer_type)
return f"{blob_container_name}-data-source{self.suffix}"

@property
def vector_search_profile_name(self):
"""Get the vector search profile name for the indexer."""
return (
f"{str(self.indexer_type.value)}-compass-vector-search-profile{self.suffix}"
)

@property
def vectorizer_name(self):
"""Get the vectorizer name."""
return f"{str(self.indexer_type.value)}-compass-vectorizer{self.suffix}"

@property
def algorithm_name(self):
"""Gtt the algorithm name"""

return f"{str(self.indexer_type.value)}-hnsw-algorithm{self.suffix}"

@abstractmethod
def get_index_fields(self) -> list[SearchableField]:
"""Get the index fields for the indexer.
Expand Down Expand Up @@ -122,6 +142,7 @@ def get_index_projections(self):
return None

def get_synonym_map_names(self):
"""Get the synonym map names for the indexer."""
return []

def get_user_assigned_managed_identity(
Expand Down Expand Up @@ -292,67 +313,7 @@ def get_text_split_skill(self, context, source) -> SplitSkill:

return text_split_skill

def get_custom_text_split_skill(
self,
context,
source,
text_split_mode="semantic",
maximum_page_length=1000,
separator=" ",
initial_threshold=0.7,
appending_threshold=0.6,
merging_threshold=0.6,
) -> WebApiSkill:
"""Get the custom skill for text split.
Args:
-----
context (str): The context of the skill
inputs (List[InputFieldMappingEntry]): The inputs of the skill
outputs (List[OutputFieldMappingEntry]): The outputs of the skill
Returns:
--------
WebApiSkill: The custom skill for text split"""

if self.test:
batch_size = 2
degree_of_parallelism = 2
else:
batch_size = 2
degree_of_parallelism = 6

text_split_skill_inputs = [
InputFieldMappingEntry(name="text", source=source),
]

headers = {
"text_split_mode": text_split_mode,
"maximum_page_length": maximum_page_length,
"separator": separator,
"initial_threshold": initial_threshold,
"appending_threshold": appending_threshold,
"merging_threshold": merging_threshold,
}

text_split_skill = WebApiSkill(
name="Text Split Skill",
description="Skill to split the text before sending to embedding",
context=context,
uri=get_custom_skill_function_url("split"),
timeout="PT230S",
batch_size=batch_size,
degree_of_parallelism=degree_of_parallelism,
http_method="POST",
http_headers=headers,
inputs=text_split_skill_inputs,
outputs=[OutputFieldMappingEntry(name="chunks", target_name="pages")],
auth_resource_id=get_function_app_authresourceid(),
auth_identity=self.get_user_assigned_managed_identity(),
)

return text_split_skill


def get_adi_skill(self, chunk_by_page=False) -> WebApiSkill:
"""Get the custom skill for adi.
Expand Down Expand Up @@ -400,6 +361,46 @@ def get_adi_skill(self, chunk_by_page=False) -> WebApiSkill:

return adi_skill

def get_excel_skill(self) -> WebApiSkill:
"""Get the custom skill for adi.
Returns:
--------
WebApiSkill: The custom skill for adi"""

if self.test:
batch_size = 1
degree_of_parallelism = 4
else:
batch_size = 1
degree_of_parallelism = 8

output = [
OutputFieldMappingEntry(name="extracted_content", target_name="pages")
]

xlsx_skill = WebApiSkill(
name="XLSX Skill",
description="Skill to generate Markdown from XLSX",
context="/document",
uri=get_custom_skill_function_url("xlsx"),
timeout="PT230S",
batch_size=batch_size,
degree_of_parallelism=degree_of_parallelism,
http_method="POST",
http_headers={},
inputs=[
InputFieldMappingEntry(
name="source", source="/document/metadata_storage_path"
)
],
outputs=output,
auth_resource_id=get_function_app_authresourceid(),
auth_identity=self.get_user_assigned_managed_identity(),
)

return xlsx_skill

def get_key_phrase_extraction_skill(self, context, source) -> WebApiSkill:
"""Get the key phrase extraction skill.
Expand Down Expand Up @@ -570,25 +571,21 @@ def get_compass_vector_search(self) -> VectorSearch:
Returns:
VectorSearch: The vector search configuration
"""
vectorizer_name = (
f"{str(self.indexer_type.value)}-compass-vectorizer{self.suffix}"
)
algorithim_name = f"{str(self.indexer_type.value)}-hnsw-algorithm{self.suffix}"

vector_search = VectorSearch(
algorithms=[
HnswAlgorithmConfiguration(name=algorithim_name),
HnswAlgorithmConfiguration(name=self.algorithm_name),
],
profiles=[
VectorSearchProfile(
name=self.vector_search_profile_name,
algorithm_configuration_name=algorithim_name,
vectorizer=vectorizer_name,
algorithm_configuration_name=self.algorithm_name,
vectorizer=self.vectorizer_name,
)
],
vectorizers=[
CustomVectorizer(
name=vectorizer_name,
name=self.vectorizer_name,
custom_web_api_parameters=CustomWebApiParameters(
uri=get_custom_skill_function_url("compass"),
auth_resource_id=get_function_app_authresourceid(),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,35 +1,45 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
from environment import get_search_endpoint, get_managed_identity_id, get_search_key,get_key_vault_url
from ai_search_with_adi.ai_search.environment import (
get_search_endpoint,
get_managed_identity_id,
get_search_key,
get_key_vault_url,
)
from azure.core.credentials import AzureKeyCredential
from azure.identity import DefaultAzureCredential,ManagedIdentityCredential,EnvironmentCredential
from azure.identity import DefaultAzureCredential
from azure.keyvault.secrets import SecretClient
from inquiry_document import InquiryDocumentAISearch


def main(args):
endpoint = get_search_endpoint()

try:
credential = DefaultAzureCredential(managed_identity_client_id =get_managed_identity_id())
credential = DefaultAzureCredential(
managed_identity_client_id=get_managed_identity_id()
)
# initializing key vault client
client = SecretClient(vault_url=get_key_vault_url(), credential=credential)
print("Using managed identity credential")
except Exception as e:
print(e)
credential = (
AzureKeyCredential(get_search_key(client=client))
)
credential = AzureKeyCredential(get_search_key(client=client))
print("Using Azure Key credential")

if args.indexer_type == "inquiry":
# Deploy the inquiry index
index_config = InquiryDocumentAISearch(
endpoint=endpoint,
credential=credential,
endpoint=endpoint,
credential=credential,
suffix=args.suffix,
rebuild=args.rebuild,
enable_page_by_chunking=args.enable_page_chunking
rebuild=args.rebuild,
enable_page_by_chunking=args.enable_page_chunking,
)
else:
raise ValueError("Invalid Indexer Type")

index_config.deploy()

if args.rebuild:
Expand All @@ -42,7 +52,7 @@ def main(args):
"--indexer_type",
type=str,
required=True,
help="Type of Indexer want to deploy. inquiry/summary/glossary",
help="Type of Indexer want to deploy.",
)
parser.add_argument(
"--rebuild",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Module providing environment definition"""
import os
from dotenv import find_dotenv, load_dotenv
Expand Down
Loading

0 comments on commit e9a0b8e

Please sign in to comment.