diff --git a/ragstack-e2e-tests/e2e_tests/llama_index/test_llama_index.py b/ragstack-e2e-tests/e2e_tests/llama_index/test_llama_index.py
index a669d324d..286b01722 100644
--- a/ragstack-e2e-tests/e2e_tests/llama_index/test_llama_index.py
+++ b/ragstack-e2e-tests/e2e_tests/llama_index/test_llama_index.py
@@ -1,5 +1,7 @@
+import os
 from abc import ABC, abstractmethod
 
+import pytest
 from e2e_tests.conftest import (
     AstraRef,
     get_required_env,
@@ -15,7 +17,8 @@
     Document,
     OpenAIEmbedding,
 )
-from llama_index.llms import OpenAI
+from llama_index.embeddings import AzureOpenAIEmbedding
+from llama_index.llms import OpenAI, AzureOpenAI
 from llama_index.vector_stores import AstraDBVectorStore
 
 
@@ -82,13 +85,52 @@ def __enter__(self):
         return OpenAIEmbedding(api_key=key)
 
 
+class AzureOpenAILLMContext(ContextMixin):
+    name = "openai-azure"
+
+    def __enter__(self):
+        return AzureOpenAI(
+            azure_deployment=get_required_env("AZURE_OPEN_AI_CHAT_MODEL_DEPLOYMENT"),
+            azure_endpoint=get_required_env("AZURE_OPEN_AI_ENDPOINT"),
+            api_key=get_required_env("AZURE_OPEN_AI_KEY"),
+            api_version="2023-07-01-preview",
+        )
+
+
+class AzureOpenAIEmbeddingsContext(ContextMixin):
+    name = "openai-azure"
+
+    def __enter__(self):
+        model_and_deployment = get_required_env(
+            "AZURE_OPEN_AI_EMBEDDINGS_MODEL_DEPLOYMENT"
+        )
+        return AzureOpenAIEmbedding(
+            model=model_and_deployment,
+            deployment_name=model_and_deployment,
+            api_key=get_required_env("AZURE_OPEN_AI_KEY"),
+            azure_endpoint=get_required_env("AZURE_OPEN_AI_ENDPOINT"),
+            api_version="2023-05-15",
+            embed_batch_size=1,
+        )
+
+
 def test_openai():
     _run_test(
         ProdAstraDBVectorStoreContext(), OpenAILLMContext(), OpenAIEmbeddingsContext()
     )
 
 
+@pytest.mark.parametrize(
+    "vector_store", [ProdAstraDBVectorStoreContext(), DevAstraDBVectorStoreContext()]
+)
+def test_openai_azure(vector_store):
+    _run_test(vector_store, AzureOpenAILLMContext(), AzureOpenAIEmbeddingsContext())
+
+
 def _run_test(vector_store_ctx, llm_ctx, embed_model_ctx):
+    os.environ[
+        "RAGSTACK_E2E_TESTS_TEST_INFO"
+    ] = f"llama_index_retrieve::{llm_ctx.name},{embed_model_ctx.name},{vector_store_ctx.name}"
     with vector_store_ctx as vector_store, llm_ctx as llm, embed_model_ctx as embed_model:
         storage_context = StorageContext.from_defaults(vector_store=vector_store)
         service_context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model)