Skip to content

Commit

Permalink
Add Llama-index Azure OpenAI tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Dec 6, 2023
1 parent b5ca84b commit 67c6e6c
Showing 1 changed file with 43 additions and 1 deletion.
44 changes: 43 additions & 1 deletion ragstack-e2e-tests/e2e_tests/llama_index/test_llama_index.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
from abc import ABC, abstractmethod

import pytest
from e2e_tests.conftest import (
AstraRef,
get_required_env,
Expand All @@ -15,7 +17,8 @@
Document,
OpenAIEmbedding,
)
from llama_index.llms import OpenAI
from llama_index.embeddings import AzureOpenAIEmbedding
from llama_index.llms import OpenAI, AzureOpenAI
from llama_index.vector_stores import AstraDBVectorStore


Expand Down Expand Up @@ -82,13 +85,52 @@ def __enter__(self):
return OpenAIEmbedding(api_key=key)


class AzureOpenAILLMContext(ContextMixin):
name = "openai-azure"

def __enter__(self):
return AzureOpenAI(
azure_deployment=get_required_env("AZURE_OPEN_AI_CHAT_MODEL_DEPLOYMENT"),
azure_endpoint=get_required_env("AZURE_OPEN_AI_ENDPOINT"),
api_key=get_required_env("AZURE_OPEN_AI_KEY"),
api_version="2023-07-01-preview",
)


class AzureOpenAIEmbeddingsContext(ContextMixin):
name = "openai-azure"

def __enter__(self):
model_and_deployment = get_required_env(
"AZURE_OPEN_AI_EMBEDDINGS_MODEL_DEPLOYMENT"
)
return AzureOpenAIEmbedding(
model=model_and_deployment,
deployment_name=model_and_deployment,
api_key=get_required_env("AZURE_OPEN_AI_KEY"),
azure_endpoint=get_required_env("AZURE_OPEN_AI_ENDPOINT"),
api_version="2023-05-15",
embed_batch_size=1,
)


def test_openai():
_run_test(
ProdAstraDBVectorStoreContext(), OpenAILLMContext(), OpenAIEmbeddingsContext()
)


@pytest.mark.parametrize(
"vector_store", [ProdAstraDBVectorStoreContext(), DevAstraDBVectorStoreContext()]
)
def test_openai_azure(vector_store):
_run_test(vector_store, AzureOpenAILLMContext(), AzureOpenAIEmbeddingsContext())


def _run_test(vector_store_ctx, llm_ctx, embed_model_ctx):
os.environ[
"RAGSTACK_E2E_TESTS_TEST_INFO"
] = f"llama_index_retrieve::{llm_ctx.name},{embed_model_ctx.name},{vector_store_ctx.name}"
with vector_store_ctx as vector_store, llm_ctx as llm, embed_model_ctx as embed_model:
storage_context = StorageContext.from_defaults(vector_store=vector_store)
service_context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model)
Expand Down

0 comments on commit 67c6e6c

Please sign in to comment.